rateslib/dual/dual_ops/eq.rs

1use crate::dual::dual::{Dual, Dual2, Vars, VarsRelationship};
2use crate::dual::enums::Number;
3
4/// Measures value equivalence of `Dual`.
5///
6/// Returns `true` if:
7///
8/// - `real` components are equal: `lhs.real == rhs.real`.
9/// - `dual` components are equal after aligning `vars`.
10impl PartialEq<Dual> for Dual {
11    fn eq(&self, other: &Dual) -> bool {
12        if self.real != other.real {
13            false
14        } else {
15            let state = self.vars_cmp(other.vars());
16            match state {
17                VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
18                    self.dual.iter().eq(other.dual.iter())
19                }
20                _ => {
21                    let (x, y) = self.to_union_vars(other, Some(state));
22                    x.dual.iter().eq(y.dual.iter())
23                }
24            }
25        }
26    }
27}
28
29impl PartialEq<f64> for Dual {
30    fn eq(&self, other: &f64) -> bool {
31        Dual::new(*other, [].to_vec()) == *self
32    }
33}
34
35impl PartialEq<f64> for Dual2 {
36    fn eq(&self, other: &f64) -> bool {
37        Dual2::new(*other, Vec::new()) == *self
38    }
39}
40
41impl PartialEq<Dual> for f64 {
42    fn eq(&self, other: &Dual) -> bool {
43        Dual::new(*self, [].to_vec()) == *other
44    }
45}
46
47impl PartialEq<Dual2> for f64 {
48    fn eq(&self, other: &Dual2) -> bool {
49        Dual2::new(*self, Vec::new()) == *other
50    }
51}
52
53impl PartialEq<Dual2> for Dual2 {
54    fn eq(&self, other: &Dual2) -> bool {
55        if self.real != other.real {
56            false
57        } else {
58            let state = self.vars_cmp(other.vars());
59            match state {
60                VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
61                    self.dual.iter().eq(other.dual.iter())
62                        && self.dual2.iter().eq(other.dual2.iter())
63                }
64                _ => {
65                    let (x, y) = self.to_union_vars(other, Some(state));
66                    x.dual.iter().eq(y.dual.iter()) && x.dual2.iter().eq(y.dual2.iter())
67                }
68            }
69        }
70    }
71}
72
73impl PartialEq<Number> for Number {
74    fn eq(&self, other: &Number) -> bool {
75        match (self, other) {
76            (Number::F64(f), Number::F64(f2)) => f == f2,
77            (Number::F64(f), Number::Dual(d2)) => f == d2,
78            (Number::F64(f), Number::Dual2(d2)) => f == d2,
79            (Number::Dual(d), Number::F64(f2)) => d == f2,
80            (Number::Dual(d), Number::Dual(d2)) => d == d2,
81            (Number::Dual(_), Number::Dual2(_)) => {
82                panic!("Cannot mix dual types: Dual == Dual2")
83            }
84            (Number::Dual2(d), Number::F64(f2)) => d == f2,
85            (Number::Dual2(_), Number::Dual(_)) => {
86                panic!("Cannot mix dual types: Dual2 == Dual")
87            }
88            (Number::Dual2(d), Number::Dual2(d2)) => d == d2,
89        }
90    }
91}
92
93impl PartialEq<f64> for Number {
94    fn eq(&self, other: &f64) -> bool {
95        match self {
96            Number::F64(f) => f == other,
97            Number::Dual(d) => d == other,
98            Number::Dual2(d) => d == other,
99        }
100    }
101}
102
103impl PartialEq<Number> for f64 {
104    fn eq(&self, other: &Number) -> bool {
105        match other {
106            Number::F64(f) => f == self,
107            Number::Dual(d) => d == self,
108            Number::Dual2(d) => d == self,
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    // Unit tests for the equality impls in this module, covering both argument
    // orders for every mixed-type impl.
    use super::*;

    #[test]
    fn eq_ne() {
        // Dual with vars - f64 (+reverse)
        assert!(Dual::new(0.0, Vec::from([String::from("a")])) != 0.0);
        assert!(0.0 != Dual::new(0.0, Vec::from([String::from("a")])));
        // Dual with no vars - f64 (+reverse)
        assert!(Dual::new(2.0, Vec::new()) == 2.0);
        assert!(2.0 == Dual::new(2.0, Vec::new()));
        // Dual - Dual (various real, vars, gradient mismatch)
        let d = Dual::try_new(2.0, Vec::from([String::from("a")]), Vec::from([2.3])).unwrap();
        assert!(d == Dual::try_new(2.0, Vec::from([String::from("a")]), Vec::from([2.3])).unwrap());
        assert!(d != Dual::try_new(2.0, Vec::from([String::from("b")]), Vec::from([2.3])).unwrap());
        assert!(d != Dual::try_new(3.0, Vec::from([String::from("a")]), Vec::from([2.3])).unwrap());
        assert!(d != Dual::try_new(2.0, Vec::from([String::from("a")]), Vec::from([1.3])).unwrap());
        // Dual - Dual (missing Vars are zero and upcasted)
        assert!(
            d == Dual::try_new(
                2.0,
                Vec::from([String::from("a"), String::from("b")]),
                Vec::from([2.3, 0.0])
            )
            .unwrap()
        );
    }

    #[test]
    fn eq_ne2() {
        // Dual2 with vars - f64 (+reverse)
        assert!(Dual2::new(0.0, Vec::from([String::from("a")])) != 0.0);
        assert!(0.0 != Dual2::new(0.0, Vec::from([String::from("a")])));
        // Dual2 with no vars - f64 (+reverse)
        assert!(Dual2::new(2.0, Vec::new()) == 2.0);
        assert!(2.0 == Dual2::new(2.0, Vec::new()));
        // Dual2 - Dual2 (various real, vars, gradient mismatch)
        let d = Dual2::try_new(
            2.0,
            Vec::from([String::from("a")]),
            Vec::from([2.3]),
            Vec::new(),
        )
        .unwrap();
        assert!(
            d == Dual2::try_new(
                2.0,
                Vec::from([String::from("a")]),
                Vec::from([2.3]),
                Vec::new()
            )
            .unwrap()
        );
        assert!(
            d != Dual2::try_new(
                2.0,
                Vec::from([String::from("b")]),
                Vec::from([2.3]),
                Vec::new()
            )
            .unwrap()
        );
        assert!(
            d != Dual2::try_new(
                3.0,
                Vec::from([String::from("a")]),
                Vec::from([2.3]),
                Vec::new()
            )
            .unwrap()
        );
        assert!(
            d != Dual2::try_new(
                2.0,
                Vec::from([String::from("a")]),
                Vec::from([1.3]),
                Vec::new()
            )
            .unwrap()
        );
        // Dual2 - Dual2 (missing Vars are zero and upcasted)
        assert!(
            d == Dual2::try_new(
                2.0,
                Vec::from([String::from("a"), String::from("b")]),
                Vec::from([2.3, 0.0]),
                Vec::new()
            )
            .unwrap()
        );
    }

    #[test]
    fn test_enum_ne() {
        let d = Number::Dual(Dual::new(2.0, vec!["x".to_string()]));
        let d2 = Number::Dual(Dual::new(3.0, vec!["x".to_string()]));
        assert_ne!(d, d2)
    }

    #[test]
    fn test_enum() {
        let d = Number::Dual(Dual::new(2.0, vec!["x".to_string()]));
        let d2 = Number::Dual(Dual::new(2.0, vec!["x".to_string()]));
        assert_eq!(d, d2)
    }

    #[test]
    fn test_cross_enum_eq() {
        let f = Number::F64(2.5_f64);
        let d = Number::Dual(Dual::new(2.5_f64, vec![]));
        let d2 = Number::Dual2(Dual2::new(2.5_f64, vec![]));
        assert_eq!(f, d);
        assert_eq!(d, f);
        assert_eq!(f, d2);
        assert_eq!(d2, f);
    }

    #[test]
    #[should_panic(expected = "Cannot mix dual types: Dual2 == Dual")]
    fn test_cross_enum_eq_error() {
        let d2 = Number::Dual2(Dual2::new(2.5_f64, vec![]));
        let d = Number::Dual(Dual::new(2.5_f64, vec![]));
        assert_eq!(d2, d);
    }

    #[test]
    #[should_panic(expected = "Cannot mix dual types: Dual == Dual2")]
    fn test_cross_enum_eq_error_reversed() {
        let d = Number::Dual(Dual::new(2.5_f64, vec![]));
        let d2 = Number::Dual2(Dual2::new(2.5_f64, vec![]));
        let _ = d == d2;
    }

    #[test]
    fn test_enum_f64() {
        let n = Number::Dual(Dual::new(2.0, vec![]));
        assert!(n == 2.0);
        // Reverse direction exercises `impl PartialEq<Number> for f64`
        // (previously this line compared `2.0 == 2.0`, which tested nothing).
        assert!(2.0 == n);
        assert!(3.0 != n);
        assert!(Number::F64(2.0) == 2.0);
        assert!(2.0 == Number::F64(2.0));
        assert!(2.0 == Number::Dual2(Dual2::new(2.0, vec![])));
    }
}