rateslib/dual/dual_ops/sub.rs

1use crate::dual::dual::{Dual, Dual2, Vars, VarsRelationship};
2use crate::dual::enums::Number;
3use auto_ops::impl_op_ex;
4use std::sync::Arc;
5
6// Sub
7impl_op_ex!(-|a: &Dual, b: &f64| -> Dual {
8    Dual {
9        vars: Arc::clone(&a.vars),
10        real: a.real - b,
11        dual: a.dual.clone(),
12    }
13});
14impl_op_ex!(-|a: &f64, b: &Dual| -> Dual {
15    Dual {
16        vars: Arc::clone(&b.vars),
17        real: a - b.real,
18        dual: -(b.dual.clone()),
19    }
20});
21impl_op_ex!(-|a: &Dual2, b: &f64| -> Dual2 {
22    Dual2 {
23        vars: Arc::clone(&a.vars),
24        real: a.real - b,
25        dual: a.dual.clone(),
26        dual2: a.dual2.clone(),
27    }
28});
29impl_op_ex!(-|a: &f64, b: &Dual2| -> Dual2 {
30    Dual2 {
31        vars: Arc::clone(&b.vars),
32        real: a - b.real,
33        dual: -(b.dual.clone()),
34        dual2: -(b.dual2.clone()),
35    }
36});
37
38// impl Sub for Dual
39impl_op_ex!(-|a: &Dual, b: &Dual| -> Dual {
40    let state = a.vars_cmp(b.vars());
41    match state {
42        VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => Dual {
43            real: a.real - b.real,
44            dual: &a.dual - &b.dual,
45            vars: Arc::clone(&a.vars),
46        },
47        _ => {
48            let (x, y) = a.to_union_vars(b, Some(state));
49            Dual {
50                real: x.real - y.real,
51                dual: &x.dual - &y.dual,
52                vars: Arc::clone(&x.vars),
53            }
54        }
55    }
56});
57
58// impl Sub
59impl_op_ex!(-|a: &Dual2, b: &Dual2| -> Dual2 {
60    let state = a.vars_cmp(b.vars());
61    match state {
62        VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => Dual2 {
63            real: a.real - b.real,
64            dual: &a.dual - &b.dual,
65            dual2: &a.dual2 - &b.dual2,
66            vars: Arc::clone(&a.vars),
67        },
68        _ => {
69            let (x, y) = a.to_union_vars(b, Some(state));
70            Dual2 {
71                real: x.real - y.real,
72                dual: &x.dual - &y.dual,
73                dual2: &x.dual2 - &y.dual2,
74                vars: Arc::clone(&x.vars),
75            }
76        }
77    }
78});
79
80// Sub for Number
81impl_op_ex!(-|a: &Number, b: &Number| -> Number {
82    match (a, b) {
83        (Number::F64(f), Number::F64(f2)) => Number::F64(f - f2),
84        (Number::F64(f), Number::Dual(d2)) => Number::Dual(f - d2),
85        (Number::F64(f), Number::Dual2(d2)) => Number::Dual2(f - d2),
86        (Number::Dual(d), Number::F64(f2)) => Number::Dual(d - f2),
87        (Number::Dual(d), Number::Dual(d2)) => Number::Dual(d - d2),
88        (Number::Dual(_), Number::Dual2(_)) => {
89            panic!("Cannot mix dual types: Dual - Dual2")
90        }
91        (Number::Dual2(d), Number::F64(f2)) => Number::Dual2(d - f2),
92        (Number::Dual2(_), Number::Dual(_)) => {
93            panic!("Cannot mix dual types: Dual2 - Dual")
94        }
95        (Number::Dual2(d), Number::Dual2(d2)) => Number::Dual2(d - d2),
96    }
97});
98
99// Sub for Number
100impl_op_ex!(-|a: &Number, b: &f64| -> Number {
101    match a {
102        Number::F64(f) => Number::F64(f - b),
103        Number::Dual(d) => Number::Dual(d - b),
104        Number::Dual2(d) => Number::Dual2(d - b),
105    }
106});
107
108// Sub for Number
109impl_op_ex!(-|a: &f64, b: &Number| -> Number {
110    match b {
111        Number::F64(f) => Number::F64(a - f),
112        Number::Dual(d) => Number::Dual(a - d),
113        Number::Dual2(d) => Number::Dual2(a - d),
114    }
115});
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sub_f64() {
        let d1 = Dual::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = (10.0 - d1) - 15.0;
        let expected = Dual::try_new(
            -6.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![-1.0, -2.0],
        )
        .unwrap();
        assert_eq!(result, expected)
    }

    #[test]
    fn sub() {
        let d1 = Dual::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let d2 = Dual::try_new(
            2.0,
            vec!["v0".to_string(), "v2".to_string()],
            vec![0.0, 3.0],
        )
        .unwrap();
        // Differing variable sets: the result is expressed over their union.
        let expected = Dual::try_new(
            -1.0,
            vec!["v0".to_string(), "v1".to_string(), "v2".to_string()],
            vec![1.0, 2.0, -3.0],
        )
        .unwrap();
        let result = d1 - d2;
        assert_eq!(result, expected)
    }

    #[test]
    fn sub_f64_2() {
        let d1 = Dual2::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            Vec::new(),
        )
        .unwrap();
        let result = (10.0 - d1) - 15.0;
        let expected = Dual2::try_new(
            -6.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![-1.0, -2.0],
            Vec::new(),
        )
        .unwrap();
        assert_eq!(result, expected)
    }

    #[test]
    fn sub2() {
        let d1 = Dual2::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            Vec::new(),
        )
        .unwrap();
        let d2 = Dual2::try_new(
            2.0,
            vec!["v0".to_string(), "v2".to_string()],
            vec![0.0, 3.0],
            Vec::new(),
        )
        .unwrap();
        let expected = Dual2::try_new(
            -1.0,
            vec!["v0".to_string(), "v1".to_string(), "v2".to_string()],
            vec![1.0, 2.0, -3.0],
            Vec::new(),
        )
        .unwrap();
        let result = d1 - d2;
        assert_eq!(result, expected)
    }

    #[test]
    fn test_enum() {
        let f = Number::F64(2.0);
        let d = Number::Dual(Dual::new(3.0, vec!["x".to_string()]));
        assert_eq!(
            &f - &d,
            Number::Dual(Dual::try_new(-1.0, vec!["x".to_string()], vec![-1.0]).unwrap())
        );

        assert_eq!(
            &d - &d,
            Number::Dual(Dual::try_new(0.0, vec!["x".to_string()], vec![0.0]).unwrap())
        );
    }

    #[test]
    fn test_enum_f64_f64() {
        // Plain scalars stay plain scalars.
        let a = Number::F64(5.0);
        let b = Number::F64(2.0);
        assert_eq!(&a - &b, Number::F64(3.0));
    }

    #[test]
    fn test_enum_dual2() {
        let d = Number::Dual2(Dual2::new(2.0, vec!["y".to_string()]));
        assert_eq!(
            &d - 0.5_f64,
            Number::Dual2(
                Dual2::try_new(1.5, vec!["y".to_string()], vec![1.0], Vec::new()).unwrap()
            )
        );
        assert_eq!(
            1.0_f64 - &d,
            Number::Dual2(
                Dual2::try_new(-1.0, vec!["y".to_string()], vec![-1.0], Vec::new()).unwrap()
            )
        );
        assert_eq!(
            &d - &d,
            Number::Dual2(
                Dual2::try_new(0.0, vec!["y".to_string()], vec![0.0], Vec::new()).unwrap()
            )
        );
    }

    // Pin the panic message so an unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "Cannot mix dual types: Dual2 - Dual")]
    fn test_enum_panic() {
        let d = Number::Dual2(Dual2::new(2.0, vec!["y".to_string()]));
        let d2 = Number::Dual(Dual::new(3.0, vec!["x".to_string()]));
        let _ = d - d2;
    }

    // The mirror-image mixing branch (Dual - Dual2) must panic as well.
    #[test]
    #[should_panic(expected = "Cannot mix dual types: Dual - Dual2")]
    fn test_enum_panic_dual_dual2() {
        let d = Number::Dual(Dual::new(3.0, vec!["x".to_string()]));
        let d2 = Number::Dual2(Dual2::new(2.0, vec!["y".to_string()]));
        let _ = d - d2;
    }

    #[test]
    fn test_enum_f64() {
        let d = Number::Dual(Dual::new(3.0, vec!["x".to_string()]));
        let res = 2.5_f64 - d;
        assert_eq!(
            res,
            Number::Dual(Dual::new(0.5, vec!["x".to_string()]) * -1.0)
        );

        let d = Number::Dual(Dual::new(3.0, vec!["x".to_string()]));
        let res = d - 2.5_f64;
        assert_eq!(res, Number::Dual(Dual::new(0.5, vec!["x".to_string()])));
    }
}