1use crate::dual::dual::{Dual, Dual2, Vars, VarsRelationship};
2use crate::dual::enums::Number;
3use auto_ops::{impl_op_ex, impl_op_ex_commutative};
4use std::sync::Arc;
5
6impl_op_ex_commutative!(+ |a: &Dual, b: &f64| -> Dual { Dual {vars: Arc::clone(&a.vars), real: a.real + b, dual: a.dual.clone()} });
8impl_op_ex_commutative!(+ |a: &Dual2, b: &f64| -> Dual2 {
9 Dual2 {vars: Arc::clone(&a.vars), real: a.real + b, dual: a.dual.clone(), dual2: a.dual2.clone()}
10});
11
12impl_op_ex!(+ |a: &Dual, b: &Dual| -> Dual {
14 let state = a.vars_cmp(b.vars());
15 match state {
16 VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
17 Dual {real: a.real + b.real, dual: &a.dual + &b.dual, vars: Arc::clone(&a.vars)}
18 }
19 _ => {
20 let (x, y) = a.to_union_vars(b, Some(state));
21 Dual {real: x.real + y.real, dual: &x.dual + &y.dual, vars: Arc::clone(&x.vars)}
22 }
23 }
24});
25
26impl_op_ex!(+ |a: &Dual2, b: &Dual2| -> Dual2 {
28 let state = a.vars_cmp(b.vars());
29 match state {
30 VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
31 Dual2 {
32 real: a.real + b.real,
33 dual: &a.dual + &b.dual,
34 dual2: &a.dual2 + &b.dual2,
35 vars: Arc::clone(&a.vars)}
36 }
37 _ => {
38 let (x, y) = a.to_union_vars(b, Some(state));
39 Dual2 {
40 real: x.real + y.real,
41 dual: &x.dual + &y.dual,
42 dual2: &x.dual2 + &y.dual2,
43 vars: Arc::clone(&x.vars)}
44 }
45 }
46});
47
48impl_op_ex!(+ |a: &Number, b: &Number| -> Number {
50 match (a,b) {
51 (Number::F64(f), Number::F64(f2)) => Number::F64(f + f2),
52 (Number::F64(f), Number::Dual(d2)) => Number::Dual(f + d2),
53 (Number::F64(f), Number::Dual2(d2)) => Number::Dual2(f + d2),
54 (Number::Dual(d), Number::F64(f2)) => Number::Dual(d + f2),
55 (Number::Dual(d), Number::Dual(d2)) => Number::Dual(d + d2),
56 (Number::Dual(_), Number::Dual2(_)) => panic!("Cannot mix dual types: Dual + Dual2"),
57 (Number::Dual2(d), Number::F64(f2)) => Number::Dual2(d + f2),
58 (Number::Dual2(_), Number::Dual(_)) => panic!("Cannot mix dual types: Dual2 + Dual"),
59 (Number::Dual2(d), Number::Dual2(d2)) => Number::Dual2(d + d2),
60 }
61});
62
63impl_op_ex_commutative!(+ |a: &Number, b: &f64| -> Number {
65 match a {
66 Number::F64(f) => Number::F64(f + b),
67 Number::Dual(d) => Number::Dual(d + b),
68 Number::Dual2(d) => Number::Dual2(d + b),
69 }
70});
71
#[cfg(test)]
mod tests {
    use super::*;

    // Builds the owned variable-name list expected by the constructors.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    // A float may be added on either side of an owned `Dual`.
    #[test]
    fn add_f64() {
        let operand = Dual::try_new(1.0, labels(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let expected = Dual::try_new(26.0, labels(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();

        let result = 10.0 + operand + 15.0;

        assert_eq!(result, expected);
    }

    // Adding two `Dual`s whose variable lists only partially overlap.
    #[test]
    fn add() {
        let lhs = Dual::try_new(1.0, labels(&["v0", "v1"]), vec![1.0, 2.0]).unwrap();
        let rhs = Dual::try_new(2.0, labels(&["v0", "v2"]), vec![0.0, 3.0]).unwrap();
        let expected =
            Dual::try_new(3.0, labels(&["v0", "v1", "v2"]), vec![1.0, 2.0, 3.0]).unwrap();

        let result = lhs + rhs;

        assert_eq!(result, expected);
    }

    // A float may be added on either side of an owned `Dual2`.
    #[test]
    fn add_f64_2() {
        let operand =
            Dual2::try_new(1.0, labels(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let expected =
            Dual2::try_new(26.0, labels(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();

        let result = 10.0 + operand + 15.0;

        assert_eq!(result, expected);
    }

    // Adding two `Dual2`s whose variable lists only partially overlap.
    #[test]
    fn add2() {
        let lhs =
            Dual2::try_new(1.0, labels(&["v0", "v1"]), vec![1.0, 2.0], Vec::new()).unwrap();
        let rhs =
            Dual2::try_new(2.0, labels(&["v0", "v2"]), vec![0.0, 3.0], Vec::new()).unwrap();
        let expected = Dual2::try_new(
            3.0,
            labels(&["v0", "v1", "v2"]),
            vec![1.0, 2.0, 3.0],
            Vec::new(),
        )
        .unwrap();

        let result = lhs + rhs;

        assert_eq!(result, expected);
    }

    // `Number` addition by reference: F64 + Dual, then Dual + Dual.
    #[test]
    fn test_enum() {
        let float = Number::F64(2.0);
        let dual = Number::Dual(Dual::new(3.0, labels(&["x"])));

        let float_plus_dual = &float + &dual;
        assert_eq!(float_plus_dual, Number::Dual(Dual::new(5.0, labels(&["x"]))));

        let doubled = &dual + &dual;
        assert_eq!(
            doubled,
            Number::Dual(Dual::try_new(6.0, labels(&["x"]), vec![2.0]).unwrap())
        );
    }

    // Combining a `Dual2` with a `Dual` inside `Number` must panic.
    #[test]
    #[should_panic]
    fn test_enum_panic() {
        let second_order = Number::Dual2(Dual2::new(2.0, labels(&["y"])));
        let first_order = Number::Dual(Dual::new(3.0, labels(&["x"])));
        let _ = second_order + first_order;
    }

    // A bare float on the left of an owned `Number`.
    #[test]
    fn test_enum_f64() {
        let dual = Number::Dual(Dual::new(3.0, labels(&["x"])));
        let sum = 2.5_f64 + dual;
        assert_eq!(sum, Number::Dual(Dual::new(5.5, labels(&["x"]))));
    }
}