1use crate::dual::dual::{Dual, Dual2, Vars, VarsRelationship};
2use crate::dual::enums::Number;
3use crate::dual::linalg::fouter11_;
4use auto_ops::{impl_op_ex, impl_op_ex_commutative};
5use ndarray::Array2;
6use std::sync::Arc;
7
8impl_op_ex_commutative!(*|a: &Dual, b: &f64| -> Dual {
10 Dual {
11 vars: Arc::clone(&a.vars),
12 real: a.real * b,
13 dual: *b * &a.dual,
14 }
15});
16impl_op_ex_commutative!(*|a: &Dual2, b: &f64| -> Dual2 {
17 Dual2 {
18 vars: Arc::clone(&a.vars),
19 real: a.real * b,
20 dual: *b * &a.dual,
21 dual2: *b * &a.dual2,
22 }
23});
24
25impl_op_ex!(*|a: &Dual, b: &Dual| -> Dual {
27 let state = a.vars_cmp(b.vars());
28 match state {
29 VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => Dual {
30 real: a.real * b.real,
31 dual: &a.dual * b.real + &b.dual * a.real,
32 vars: Arc::clone(&a.vars),
33 },
34 _ => {
35 let (x, y) = a.to_union_vars(b, Some(state));
36 Dual {
37 real: x.real * y.real,
38 dual: &x.dual * y.real + &y.dual * x.real,
39 vars: Arc::clone(&x.vars),
40 }
41 }
42 }
43});
44
45impl_op_ex!(*|a: &Dual2, b: &Dual2| -> Dual2 {
47 let state = a.vars_cmp(b.vars());
48 match state {
49 VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
50 let mut dual2: Array2<f64> = &a.dual2 * b.real + &b.dual2 * a.real;
51 let cross_beta = fouter11_(&a.dual.view(), &b.dual.view());
52 dual2 = dual2 + 0.5_f64 * (&cross_beta + &cross_beta.t());
53 Dual2 {
54 real: a.real * b.real,
55 dual: &a.dual * b.real + &b.dual * a.real,
56 vars: Arc::clone(&a.vars),
57 dual2,
58 }
59 }
60 _ => {
61 let (x, y) = a.to_union_vars(b, Some(state));
62 let mut dual2: Array2<f64> = &x.dual2 * y.real + &y.dual2 * x.real;
63 let cross_beta = fouter11_(&x.dual.view(), &y.dual.view());
64 dual2 = dual2 + 0.5_f64 * (&cross_beta + &cross_beta.t());
65 Dual2 {
66 real: x.real * y.real,
67 dual: &x.dual * y.real + &y.dual * x.real,
68 vars: Arc::clone(&x.vars),
69 dual2,
70 }
71 }
72 }
73});
74
75impl_op_ex!(*|a: &Number, b: &Number| -> Number {
77 match (a, b) {
78 (Number::F64(f), Number::F64(f2)) => Number::F64(f * f2),
79 (Number::F64(f), Number::Dual(d2)) => Number::Dual(f * d2),
80 (Number::F64(f), Number::Dual2(d2)) => Number::Dual2(f * d2),
81 (Number::Dual(d), Number::F64(f2)) => Number::Dual(d * f2),
82 (Number::Dual(d), Number::Dual(d2)) => Number::Dual(d * d2),
83 (Number::Dual(_), Number::Dual2(_)) => {
84 panic!("Cannot mix dual types: Dual * Dual2")
85 }
86 (Number::Dual2(d), Number::F64(f2)) => Number::Dual2(d * f2),
87 (Number::Dual2(_), Number::Dual(_)) => {
88 panic!("Cannot mix dual types: Dual2 * Dual")
89 }
90 (Number::Dual2(d), Number::Dual2(d2)) => Number::Dual2(d * d2),
91 }
92});
93
94impl_op_ex_commutative!(*|a: &Number, b: &f64| -> Number {
96 match a {
97 Number::F64(f) => Number::F64(f * b),
98 Number::Dual(d) => Number::Dual(d * b),
99 Number::Dual2(d) => Number::Dual2(d * b),
100 }
101});
102
#[cfg(test)]
mod tests {
    use super::*;

    // Builds the owned variable-name vector that the constructors take.
    fn names(labels: &[&str]) -> Vec<String> {
        labels.iter().map(|label| label.to_string()).collect()
    }

    // `f64 * Dual * f64` scales the real part and every first-order coefficient.
    #[test]
    fn mul_f64() {
        let operand = Dual::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let expected = Dual::try_new(20.0, names(&["v0", "v1"]), vec![20.0, 40.0]).unwrap();
        let result = 10.0 * operand * 2.0;
        assert_eq!(result, expected);
    }

    // `Dual * Dual` with different variable sets yields a result on their union.
    #[test]
    fn mul() {
        let lhs = Dual::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let rhs = Dual::try_new(2.0, names(&["v0", "v2"]), vec![0.0, 3.0]).unwrap();
        let expected =
            Dual::try_new(2.0, names(&["v0", "v1", "v2"]), vec![2.0, 4.0, 3.0]).unwrap();
        let result = lhs * rhs;
        assert_eq!(result, expected);
    }

    // `f64 * Dual2 * f64` scales the real part and every first-order coefficient.
    #[test]
    fn mul_f64_2() {
        let operand =
            Dual2::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let expected =
            Dual2::try_new(20.0, names(&["v0", "v1"]), vec![20.0, 40.0], Vec::new()).unwrap();
        let result = 10.0 * operand * 2.0;
        assert_eq!(result, expected);
    }

    // `Dual2 * Dual2` with different variable sets also produces the cross terms.
    #[test]
    fn mul2() {
        let lhs = Dual2::try_new(1.0, names(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let rhs = Dual2::try_new(2.0, names(&["v0", "v2"]), vec![0.0, 3.0], Vec::new()).unwrap();
        let expected = Dual2::try_new(
            2.0,
            names(&["v0", "v1", "v2"]),
            vec![2.0, 4.0, 3.0],
            vec![0., 0., 1.5, 0., 0., 3., 1.5, 3., 0.],
        )
        .unwrap();
        let result = lhs * rhs;
        assert_eq!(result, expected);
    }

    // `Number` multiplication delegates to the wrapped `f64` / `Dual` values.
    #[test]
    fn test_enum() {
        let scalar = Number::F64(2.0);
        let dual = Number::Dual(Dual::new(3.0, names(&["x"])));

        let scaled = Number::Dual(Dual::try_new(6.0, names(&["x"]), vec![2.0]).unwrap());
        assert_eq!(&scalar * &dual, scaled);

        let squared = Number::Dual(Dual::try_new(9.0, names(&["x"]), vec![6.0]).unwrap());
        assert_eq!(&dual * &dual, squared);
    }

    // Multiplying a `Dual2` variant by a `Dual` variant must panic.
    #[test]
    #[should_panic]
    fn test_enum_panic() {
        let second_order = Number::Dual2(Dual2::new(2.0, names(&["y"])));
        let first_order = Number::Dual(Dual::new(3.0, names(&["x"])));
        let _ = second_order * first_order;
    }

    // `f64 * Number` matches multiplying the wrapped `Dual` by the same float.
    #[test]
    fn test_enum_f64() {
        let wrapped = Number::Dual(Dual::new(3.0, names(&["x"])));
        let result = 2.0_f64 * wrapped;
        let expected = Number::Dual(Dual::new(3.0, names(&["x"])) * 2.0);
        assert_eq!(result, expected);
    }
}