rateslib/dual/dual_ops/
mul.rs

1use crate::dual::dual::{Dual, Dual2, Vars, VarsRelationship};
2use crate::dual::enums::Number;
3use crate::dual::linalg::fouter11_;
4use auto_ops::{impl_op_ex, impl_op_ex_commutative};
5use ndarray::Array2;
6use std::sync::Arc;
7
8// Mul
9impl_op_ex_commutative!(*|a: &Dual, b: &f64| -> Dual {
10    Dual {
11        vars: Arc::clone(&a.vars),
12        real: a.real * b,
13        dual: *b * &a.dual,
14    }
15});
16impl_op_ex_commutative!(*|a: &Dual2, b: &f64| -> Dual2 {
17    Dual2 {
18        vars: Arc::clone(&a.vars),
19        real: a.real * b,
20        dual: *b * &a.dual,
21        dual2: *b * &a.dual2,
22    }
23});
24
25// impl Mul for Dual
26impl_op_ex!(*|a: &Dual, b: &Dual| -> Dual {
27    let state = a.vars_cmp(b.vars());
28    match state {
29        VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => Dual {
30            real: a.real * b.real,
31            dual: &a.dual * b.real + &b.dual * a.real,
32            vars: Arc::clone(&a.vars),
33        },
34        _ => {
35            let (x, y) = a.to_union_vars(b, Some(state));
36            Dual {
37                real: x.real * y.real,
38                dual: &x.dual * y.real + &y.dual * x.real,
39                vars: Arc::clone(&x.vars),
40            }
41        }
42    }
43});
44
45// impl Mul for Dual2
46impl_op_ex!(*|a: &Dual2, b: &Dual2| -> Dual2 {
47    let state = a.vars_cmp(b.vars());
48    match state {
49        VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
50            let mut dual2: Array2<f64> = &a.dual2 * b.real + &b.dual2 * a.real;
51            let cross_beta = fouter11_(&a.dual.view(), &b.dual.view());
52            dual2 = dual2 + 0.5_f64 * (&cross_beta + &cross_beta.t());
53            Dual2 {
54                real: a.real * b.real,
55                dual: &a.dual * b.real + &b.dual * a.real,
56                vars: Arc::clone(&a.vars),
57                dual2,
58            }
59        }
60        _ => {
61            let (x, y) = a.to_union_vars(b, Some(state));
62            let mut dual2: Array2<f64> = &x.dual2 * y.real + &y.dual2 * x.real;
63            let cross_beta = fouter11_(&x.dual.view(), &y.dual.view());
64            dual2 = dual2 + 0.5_f64 * (&cross_beta + &cross_beta.t());
65            Dual2 {
66                real: x.real * y.real,
67                dual: &x.dual * y.real + &y.dual * x.real,
68                vars: Arc::clone(&x.vars),
69                dual2,
70            }
71        }
72    }
73});
74
75// Mul for Number
76impl_op_ex!(*|a: &Number, b: &Number| -> Number {
77    match (a, b) {
78        (Number::F64(f), Number::F64(f2)) => Number::F64(f * f2),
79        (Number::F64(f), Number::Dual(d2)) => Number::Dual(f * d2),
80        (Number::F64(f), Number::Dual2(d2)) => Number::Dual2(f * d2),
81        (Number::Dual(d), Number::F64(f2)) => Number::Dual(d * f2),
82        (Number::Dual(d), Number::Dual(d2)) => Number::Dual(d * d2),
83        (Number::Dual(_), Number::Dual2(_)) => {
84            panic!("Cannot mix dual types: Dual * Dual2")
85        }
86        (Number::Dual2(d), Number::F64(f2)) => Number::Dual2(d * f2),
87        (Number::Dual2(_), Number::Dual(_)) => {
88            panic!("Cannot mix dual types: Dual2 * Dual")
89        }
90        (Number::Dual2(d), Number::Dual2(d2)) => Number::Dual2(d * d2),
91    }
92});
93
94// Mul for Number
95impl_op_ex_commutative!(*|a: &Number, b: &f64| -> Number {
96    match a {
97        Number::F64(f) => Number::F64(f * b),
98        Number::Dual(d) => Number::Dual(d * b),
99        Number::Dual2(d) => Number::Dual2(d * b),
100    }
101});
102
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience: turn string slices into the owned `Vec<String>` the constructors take.
    fn names(labels: &[&str]) -> Vec<String> {
        labels.iter().map(|&s| String::from(s)).collect()
    }

    #[test]
    fn mul_f64() {
        let d1 = Dual::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let expected = Dual::try_new(20.0, names(&["v0", "v1"]), vec![20.0, 40.0]).unwrap();
        // Scalars on both sides: (10 * d1) * 2.
        assert_eq!(10.0 * d1 * 2.0, expected)
    }

    #[test]
    fn mul() {
        let d1 = Dual::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let d2 = Dual::try_new(2.0, names(&["v0", "v2"]), vec![0.0, 3.0]).unwrap();
        // Operands carry different vars; the product spans all three.
        let expected =
            Dual::try_new(2.0, names(&["v0", "v1", "v2"]), vec![2.0, 4.0, 3.0]).unwrap();
        assert_eq!(d1 * d2, expected)
    }

    #[test]
    fn mul_f64_2() {
        let d1 = Dual2::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let expected =
            Dual2::try_new(20.0, names(&["v0", "v1"]), vec![20.0, 40.0], Vec::new()).unwrap();
        assert_eq!(10.0 * d1 * 2.0, expected)
    }

    #[test]
    fn mul2() {
        let d1 = Dual2::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let d2 = Dual2::try_new(2.0, names(&["v0", "v2"]), vec![0.0, 3.0], Vec::new()).unwrap();
        // Second-order part is 0.5 * (g1 g2^T + g2 g1^T) with g1 = [1, 2, 0], g2 = [0, 0, 3].
        let expected = Dual2::try_new(
            2.0,
            names(&["v0", "v1", "v2"]),
            vec![2.0, 4.0, 3.0],
            vec![0., 0., 1.5, 0., 0., 3., 1.5, 3., 0.],
        )
        .unwrap();
        assert_eq!(d1 * d2, expected)
    }

    #[test]
    fn test_enum() {
        let f = Number::F64(2.0);
        let d = Number::Dual(Dual::new(3.0, names(&["x"])));

        let doubled = Number::Dual(Dual::try_new(6.0, names(&["x"]), vec![2.0]).unwrap());
        assert_eq!(&f * &d, doubled);

        let squared = Number::Dual(Dual::try_new(9.0, names(&["x"]), vec![6.0]).unwrap());
        assert_eq!(&d * &d, squared);
    }

    #[test]
    #[should_panic]
    fn test_enum_panic() {
        let lhs = Number::Dual2(Dual2::new(2.0, names(&["y"])));
        let rhs = Number::Dual(Dual::new(3.0, names(&["x"])));
        let _ = lhs * rhs;
    }

    #[test]
    fn test_enum_f64() {
        let d = Number::Dual(Dual::new(3.0, names(&["x"])));
        let expected = Number::Dual(Dual::new(3.0, names(&["x"])) * 2.0);
        assert_eq!(2.0_f64 * d, expected);
    }
}