rateslib/dual/dual_ops/
signed.rs

1use crate::dual::dual::{Dual, Dual2};
2use crate::dual::enums::Number;
3use num_traits::Signed;
4use std::sync::Arc;
5
6/// Sign for `Dual` is evaluated in terms of the `real` component.
7impl Signed for Dual {
8    /// Determine the absolute value of `Dual`.
9    ///
10    /// If `real` is negative the returned `Dual` will negate both its `real` value and
11    /// `dual`.
12    ///
13    /// <div class="warning">This behaviour is undefined at zero. The derivative of the `abs` function is
14    /// not defined there and care needs to be taken when implying gradients.</div>
15    fn abs(&self) -> Self {
16        if self.real > 0.0 {
17            Dual {
18                real: self.real,
19                vars: Arc::clone(&self.vars),
20                dual: self.dual.clone(),
21            }
22        } else {
23            Dual {
24                real: -self.real,
25                vars: Arc::clone(&self.vars),
26                dual: -1.0 * &self.dual,
27            }
28        }
29    }
30
31    fn abs_sub(&self, other: &Self) -> Self {
32        if self <= other {
33            Dual::new(0.0, Vec::new())
34        } else {
35            self - other
36        }
37    }
38
39    fn signum(&self) -> Self {
40        Dual::new(self.real.signum(), Vec::new())
41    }
42
43    fn is_positive(&self) -> bool {
44        self.real.is_sign_positive()
45    }
46
47    fn is_negative(&self) -> bool {
48        self.real.is_sign_negative()
49    }
50}
51
52impl Signed for Dual2 {
53    fn abs(&self) -> Self {
54        if self.real > 0.0 {
55            Dual2 {
56                real: self.real,
57                vars: Arc::clone(&self.vars),
58                dual: self.dual.clone(),
59                dual2: self.dual2.clone(),
60            }
61        } else {
62            Dual2 {
63                real: -self.real,
64                vars: Arc::clone(&self.vars),
65                dual: -1.0 * &self.dual,
66                dual2: -1.0 * &self.dual2,
67            }
68        }
69    }
70
71    fn abs_sub(&self, other: &Self) -> Self {
72        if self <= other {
73            Dual2::new(0.0, Vec::new())
74        } else {
75            self - other
76        }
77    }
78
79    fn signum(&self) -> Self {
80        Dual2::new(self.real.signum(), Vec::new())
81    }
82
83    fn is_positive(&self) -> bool {
84        self.real.is_sign_positive()
85    }
86
87    fn is_negative(&self) -> bool {
88        self.real.is_sign_negative()
89    }
90}
91
92impl Signed for Number {
93    fn abs(&self) -> Self {
94        match self {
95            Number::F64(f) => Number::F64(f.abs()),
96            Number::Dual(d) => Number::Dual(d.abs()),
97            Number::Dual2(d) => Number::Dual2(d.abs()),
98        }
99    }
100
101    fn abs_sub(&self, other: &Self) -> Self {
102        match (self, other) {
103            (Number::F64(f), Number::F64(f2)) => Number::F64(f.abs_sub(f2)),
104            (Number::F64(f), Number::Dual(d2)) => Number::Dual(Dual::new(*f, vec![]).abs_sub(d2)),
105            (Number::F64(f), Number::Dual2(d2)) => {
106                Number::Dual2(Dual2::new(*f, vec![]).abs_sub(d2))
107            }
108            (Number::Dual(d), Number::F64(f2)) => Number::Dual(d.abs_sub(&Dual::new(*f2, vec![]))),
109            (Number::Dual(d), Number::Dual(d2)) => Number::Dual(d.abs_sub(d2)),
110            (Number::Dual(_), Number::Dual2(_)) => {
111                panic!("Cannot mix dual types: Dual / Dual2")
112            }
113            (Number::Dual2(d), Number::F64(f2)) => {
114                Number::Dual2(d.abs_sub(&Dual2::new(*f2, vec![])))
115            }
116            (Number::Dual2(_), Number::Dual(_)) => {
117                panic!("Cannot mix dual types: Dual2 / Dual")
118            }
119            (Number::Dual2(d), Number::Dual2(d2)) => Number::Dual2(d.abs_sub(d2)),
120        }
121    }
122
123    fn signum(&self) -> Self {
124        match self {
125            Number::F64(f) => Number::F64(f.signum()),
126            Number::Dual(d) => Number::Dual(d.signum()),
127            Number::Dual2(d) => Number::Dual2(d.signum()),
128        }
129    }
130
131    fn is_positive(&self) -> bool {
132        match self {
133            Number::F64(f) => f.is_positive(),
134            Number::Dual(d) => d.is_positive(),
135            Number::Dual2(d) => d.is_positive(),
136        }
137    }
138
139    fn is_negative(&self) -> bool {
140        match self {
141            Number::F64(f) => f.is_negative(),
142            Number::Dual(d) => d.is_negative(),
143            Number::Dual2(d) => d.is_negative(),
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::{One, Zero};

    #[test]
    fn signed() {
        let d1 = Dual::new(3.0, vec!["x".to_string()]);
        let d2 = Dual::new(-2.0, vec!["x".to_string()]);

        assert!(d2.is_negative());
        assert!(d1.is_positive());
        assert_eq!(d2.signum(), -1.0 * Dual::one());
        assert_eq!(d1.signum(), Dual::one());
        assert_eq!(d1.abs_sub(&d2), Dual::new(5.0, Vec::new()));
        assert_eq!(d2.abs_sub(&d1), Dual::zero());
    }

    #[test]
    fn signed_2() {
        let d1 = Dual2::new(3.0, vec!["x".to_string()]);
        let d2 = Dual2::new(-2.0, vec!["x".to_string()]);

        assert!(d2.is_negative());
        assert!(d1.is_positive());
        assert_eq!(d2.signum(), -1.0 * Dual2::one());
        assert_eq!(d1.signum(), Dual2::one());
        assert_eq!(d1.abs_sub(&d2), Dual2::new(5.0, Vec::new()));
        assert_eq!(d2.abs_sub(&d1), Dual2::zero());
    }

    #[test]
    fn abs() {
        let d1 = Dual::try_new(
            -2.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = d1.abs();
        let expected = Dual::try_new(
            2.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![-1.0, -2.0],
        )
        .unwrap();
        assert_eq!(result, expected);

        // `abs` of an already-positive value must leave value and gradient untouched.
        let result = result.abs();
        assert_eq!(result, expected);
    }

    #[test]
    fn abs2() {
        let d1 = Dual2::try_new(
            -2.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            Vec::new(),
        )
        .unwrap();
        let result = d1.abs();
        let expected = Dual2::try_new(
            2.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![-1.0, -2.0],
            Vec::new(),
        )
        .unwrap();
        assert_eq!(result, expected);

        // `abs` of an already-positive value must leave value and gradients untouched.
        let result = result.abs();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_enum() {
        let d = Number::Dual(Dual::new(-2.5, vec!["x".to_string()]));
        assert!(!d.is_positive());
        assert!(d.is_negative());
        assert_eq!(
            d.abs(),
            Number::Dual(Dual::try_new(2.5, vec!["x".to_string()], vec![-1.0]).unwrap())
        );
    }

    #[test]
    fn test_enum_f64() {
        let f = Number::F64(-1.5);
        assert!(f.is_negative());
        assert!(!f.is_positive());
        assert_eq!(f.abs(), Number::F64(1.5));
        assert_eq!(f.signum(), Number::F64(-1.0));
    }

    #[test]
    fn test_enum_dual2() {
        let d = Number::Dual2(Dual2::new(-1.0, vec!["x".to_string()]));
        assert!(d.is_negative());
        assert!(!d.is_positive());
        assert_eq!(
            d.abs(),
            Number::Dual2(
                Dual2::try_new(1.0, vec!["x".to_string()], vec![-1.0], Vec::new()).unwrap()
            )
        );
    }

    #[test]
    fn test_enum_abs_sub() {
        assert_eq!(Number::F64(5.0).abs_sub(&Number::F64(2.0)), Number::F64(3.0));
        assert_eq!(Number::F64(1.0).abs_sub(&Number::F64(2.0)), Number::F64(0.0));

        // A float paired with a dual is promoted; when it is not larger the result is a
        // constant zero of the dual's kind.
        let d = Number::Dual(Dual::new(2.0, vec!["x".to_string()]));
        assert_eq!(
            Number::F64(1.0).abs_sub(&d),
            Number::Dual(Dual::new(0.0, Vec::new()))
        );
        let d2 = Number::Dual2(Dual2::new(2.0, vec!["x".to_string()]));
        assert_eq!(
            Number::F64(1.0).abs_sub(&d2),
            Number::Dual2(Dual2::new(0.0, Vec::new()))
        );
    }

    #[test]
    #[should_panic(expected = "Cannot mix dual types")]
    fn test_enum_abs_sub_mixed_dual_panics() {
        let d = Number::Dual(Dual::new(1.0, vec!["x".to_string()]));
        let d2 = Number::Dual2(Dual2::new(1.0, vec!["x".to_string()]));
        let _ = d.abs_sub(&d2);
    }
}