rateslib/dual/dual_ops/
add.rs

1use crate::dual::dual::{Dual, Dual2, Vars, VarsRelationship};
2use crate::dual::enums::Number;
3use auto_ops::{impl_op_ex, impl_op_ex_commutative};
4use std::sync::Arc;
5
6// Add f64
7impl_op_ex_commutative!(+ |a: &Dual, b: &f64| -> Dual { Dual {vars: Arc::clone(&a.vars), real: a.real + b, dual: a.dual.clone()} });
8impl_op_ex_commutative!(+ |a: &Dual2, b: &f64| -> Dual2 {
9    Dual2 {vars: Arc::clone(&a.vars), real: a.real + b, dual: a.dual.clone(), dual2: a.dual2.clone()}
10});
11
12// Add for Dual
13impl_op_ex!(+ |a: &Dual, b: &Dual| -> Dual {
14    let state = a.vars_cmp(b.vars());
15    match state {
16        VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
17            Dual {real: a.real + b.real, dual: &a.dual + &b.dual, vars: Arc::clone(&a.vars)}
18        }
19        _ => {
20            let (x, y) = a.to_union_vars(b, Some(state));
21            Dual {real: x.real + y.real, dual: &x.dual + &y.dual, vars: Arc::clone(&x.vars)}
22        }
23    }
24});
25
26// Add for Dual2
27impl_op_ex!(+ |a: &Dual2, b: &Dual2| -> Dual2 {
28    let state = a.vars_cmp(b.vars());
29    match state {
30        VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
31            Dual2 {
32                real: a.real + b.real,
33                dual: &a.dual + &b.dual,
34                dual2: &a.dual2 + &b.dual2,
35                vars: Arc::clone(&a.vars)}
36        }
37        _ => {
38            let (x, y) = a.to_union_vars(b, Some(state));
39            Dual2 {
40                real: x.real + y.real,
41                dual: &x.dual + &y.dual,
42                dual2: &x.dual2 + &y.dual2,
43                vars: Arc::clone(&x.vars)}
44        }
45    }
46});
47
48// Add for Number
49impl_op_ex!(+ |a: &Number, b: &Number| -> Number {
50    match (a,b) {
51        (Number::F64(f), Number::F64(f2)) => Number::F64(f + f2),
52        (Number::F64(f), Number::Dual(d2)) => Number::Dual(f + d2),
53        (Number::F64(f), Number::Dual2(d2)) => Number::Dual2(f + d2),
54        (Number::Dual(d), Number::F64(f2)) => Number::Dual(d + f2),
55        (Number::Dual(d), Number::Dual(d2)) => Number::Dual(d + d2),
56        (Number::Dual(_), Number::Dual2(_)) => panic!("Cannot mix dual types: Dual + Dual2"),
57        (Number::Dual2(d), Number::F64(f2)) => Number::Dual2(d + f2),
58        (Number::Dual2(_), Number::Dual(_)) => panic!("Cannot mix dual types: Dual2 + Dual"),
59        (Number::Dual2(d), Number::Dual2(d2)) => Number::Dual2(d + d2),
60    }
61});
62
63// Add for Number
64impl_op_ex_commutative!(+ |a: &Number, b: &f64| -> Number {
65    match a {
66        Number::F64(f) => Number::F64(f + b),
67        Number::Dual(d) => Number::Dual(d + b),
68        Number::Dual2(d) => Number::Dual2(d + b),
69    }
70});
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned list of variable names passed to the `Dual`/`Dual2` constructors.
    fn names(labels: &[&str]) -> Vec<String> {
        labels.iter().copied().map(String::from).collect()
    }

    /// An f64 can be added on either side of a `Dual`; only the real value changes.
    #[test]
    fn add_f64() {
        let d1 = Dual::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let result = 10.0 + d1 + 15.0;
        let expected = Dual::try_new(26.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        assert_eq!(result, expected)
    }

    /// Adding `Dual`s with different variables gives a result over all three variables.
    #[test]
    fn add() {
        let d1 = Dual::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let d2 = Dual::try_new(2.0, names(&["v0", "v2"]), vec![0.0, 3.0]).unwrap();
        let expected =
            Dual::try_new(3.0, names(&["v0", "v1", "v2"]), vec![1.0, 2.0, 3.0]).unwrap();
        let result = d1 + d2;
        assert_eq!(result, expected)
    }

    /// An f64 can be added on either side of a `Dual2`; only the real value changes.
    #[test]
    fn add_f64_2() {
        let d1 = Dual2::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let result = 10.0 + d1 + 15.0;
        let expected =
            Dual2::try_new(26.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        assert_eq!(result, expected)
    }

    /// Adding `Dual2`s with different variables gives a result over all three variables.
    #[test]
    fn add2() {
        let d1 = Dual2::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let d2 = Dual2::try_new(2.0, names(&["v0", "v2"]), vec![0.0, 3.0], Vec::new()).unwrap();
        let expected = Dual2::try_new(
            3.0,
            names(&["v0", "v1", "v2"]),
            vec![1.0, 2.0, 3.0],
            Vec::new(),
        )
        .unwrap();
        let result = d1 + d2;
        assert_eq!(result, expected)
    }

    /// `Number` addition delegates to the wrapped value for F64 + Dual and Dual + Dual.
    #[test]
    fn test_enum() {
        let f = Number::F64(2.0);
        let d = Number::Dual(Dual::new(3.0, names(&["x"])));
        assert_eq!(&f + &d, Number::Dual(Dual::new(5.0, names(&["x"]))));

        assert_eq!(
            &d + &d,
            Number::Dual(Dual::try_new(6.0, names(&["x"]), vec![2.0]).unwrap())
        );
    }

    /// Mixing `Dual2` with `Dual` must panic with the dedicated message, so that an
    /// unrelated panic cannot make this test pass by accident.
    #[test]
    #[should_panic(expected = "Cannot mix dual types: Dual2 + Dual")]
    fn test_enum_panic() {
        let d = Number::Dual2(Dual2::new(2.0, names(&["y"])));
        let d2 = Number::Dual(Dual::new(3.0, names(&["x"])));
        let _ = d + d2;
    }

    /// The mirrored operand order (`Dual` + `Dual2`) must panic as well.
    #[test]
    #[should_panic(expected = "Cannot mix dual types: Dual + Dual2")]
    fn test_enum_panic_reversed() {
        let d = Number::Dual(Dual::new(3.0, names(&["x"])));
        let d2 = Number::Dual2(Dual2::new(2.0, names(&["y"])));
        let _ = d + d2;
    }

    /// An f64 on the left-hand side of a `Number` is supported.
    #[test]
    fn test_enum_f64() {
        let d = Number::Dual(Dual::new(3.0, names(&["x"])));
        let res = 2.5_f64 + d;
        assert_eq!(res, Number::Dual(Dual::new(5.5, names(&["x"]))));
    }
}