rateslib/dual/dual_ops/
math_funcs.rs

1use crate::dual::dual::{Dual, Dual2};
2use crate::dual::enums::Number;
3use crate::dual::linalg::fouter11_;
4use num_traits::Pow;
5use statrs::distribution::{ContinuousCDF, Normal};
6use std::f64::consts::PI;
7use std::sync::Arc;
8
/// Elementary transcendental and probability functions shared by numeric types.
///
/// Every method takes the receiver by shared reference and returns a new value
/// of the implementing type, so implementors can carry extra information (for
/// example derivatives) through each operation.
pub trait MathFuncs {
    /// Returns `e` raised to the power of the value.
    fn exp(&self) -> Self;
    /// Returns the natural (base-`e`) logarithm of the value.
    fn log(&self) -> Self;
    /// Returns `Φ(x)`: the standard normal cumulative distribution function
    /// evaluated at the value.
    fn norm_cdf(&self) -> Self;
    /// Returns `Φ⁻¹(p)`: the standard normal quantile function (inverse CDF)
    /// evaluated at the value.
    fn inv_norm_cdf(&self) -> Self;
}
20
21impl MathFuncs for Dual {
22    fn exp(&self) -> Self {
23        let c = self.real.exp();
24        Dual {
25            real: c,
26            vars: Arc::clone(&self.vars),
27            dual: c * &self.dual,
28        }
29    }
30    fn log(&self) -> Self {
31        Dual {
32            real: self.real.ln(),
33            vars: Arc::clone(&self.vars),
34            dual: (1.0 / self.real) * &self.dual,
35        }
36    }
37    fn norm_cdf(&self) -> Self {
38        let n = Normal::new(0.0, 1.0).unwrap();
39        let base = n.cdf(self.real);
40        let scalar = 1.0 / (2.0 * PI).sqrt() * (-0.5_f64 * self.real.pow(2.0_f64)).exp();
41        Dual {
42            real: base,
43            vars: Arc::clone(&self.vars),
44            dual: scalar * &self.dual,
45        }
46    }
47    fn inv_norm_cdf(&self) -> Self {
48        let n = Normal::new(0.0, 1.0).unwrap();
49        let base = n.inverse_cdf(self.real);
50        let scalar = (2.0 * PI).sqrt() * (0.5_f64 * base.pow(2.0_f64)).exp();
51        Dual {
52            real: base,
53            vars: Arc::clone(&self.vars),
54            dual: scalar * &self.dual,
55        }
56    }
57}
58
59impl MathFuncs for Dual2 {
60    fn exp(&self) -> Self {
61        let c = self.real.exp();
62        Dual2 {
63            real: c,
64            vars: Arc::clone(&self.vars),
65            dual: c * &self.dual,
66            dual2: c * (&self.dual2 + 0.5 * fouter11_(&self.dual.view(), &self.dual.view())),
67        }
68    }
69    fn log(&self) -> Self {
70        let scalar = 1.0 / self.real;
71        Dual2 {
72            real: self.real.ln(),
73            vars: Arc::clone(&self.vars),
74            dual: scalar * &self.dual,
75            dual2: scalar * &self.dual2
76                - fouter11_(&self.dual.view(), &self.dual.view()) * 0.5 * (scalar * scalar),
77        }
78    }
79    fn norm_cdf(&self) -> Self {
80        let n = Normal::new(0.0, 1.0).unwrap();
81        let base = n.cdf(self.real);
82        let scalar = 1.0 / (2.0 * PI).sqrt() * (-0.5_f64 * self.real.pow(2.0_f64)).exp();
83        let scalar2 = scalar * -self.real;
84        let cross_beta = fouter11_(&self.dual.view(), &self.dual.view());
85        Dual2 {
86            real: base,
87            vars: Arc::clone(&self.vars),
88            dual: scalar * &self.dual,
89            dual2: scalar * &self.dual2 + 0.5_f64 * scalar2 * cross_beta,
90        }
91    }
92    fn inv_norm_cdf(&self) -> Self {
93        let n = Normal::new(0.0, 1.0).unwrap();
94        let base = n.inverse_cdf(self.real);
95        let scalar = (2.0 * PI).sqrt() * (0.5_f64 * base.pow(2.0_f64)).exp();
96        let scalar2 = scalar.pow(2.0_f64) * base;
97        let cross_beta = fouter11_(&self.dual.view(), &self.dual.view());
98        Dual2 {
99            real: base,
100            vars: Arc::clone(&self.vars),
101            dual: scalar * &self.dual,
102            dual2: scalar * &self.dual2 + 0.5_f64 * scalar2 * cross_beta,
103        }
104    }
105}
106
107impl MathFuncs for f64 {
108    fn inv_norm_cdf(&self) -> Self {
109        Normal::new(0.0, 1.0).unwrap().inverse_cdf(*self)
110    }
111    fn norm_cdf(&self) -> Self {
112        Normal::new(0.0, 1.0).unwrap().cdf(*self)
113    }
114    fn exp(&self) -> Self {
115        f64::exp(*self)
116    }
117    fn log(&self) -> Self {
118        f64::ln(*self)
119    }
120}
121
// Forwards a unary `MathFuncs` method to whichever value a `Number` wraps and
// re-wraps the result in the same variant.
// Usage: `math_func!(number_ident, method_name)`.
macro_rules! math_func {
    ($number: ident, $method: ident) => {
        match $number {
            Number::Dual2(inner) => Number::Dual2(inner.$method()),
            Number::Dual(inner) => Number::Dual(inner.$method()),
            Number::F64(inner) => Number::F64(inner.$method()),
        }
    };
}
131
132impl MathFuncs for Number {
133    fn inv_norm_cdf(&self) -> Self {
134        math_func!(self, inv_norm_cdf)
135    }
136    fn norm_cdf(&self) -> Self {
137        math_func!(self, norm_cdf)
138    }
139    fn exp(&self) -> Self {
140        math_func!(self, exp)
141    }
142    fn log(&self) -> Self {
143        math_func!(self, log)
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that pass through statrs' numerical
    /// evaluation of the normal CDF and its inverse.
    const TOL: f64 = 1e-9;

    /// Asserts that `got` and `want` have equal length and agree element-wise
    /// to within `TOL`.
    fn assert_close(got: &[f64], want: &[f64]) {
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g - w).abs() < TOL, "{} vs {}", g, w);
        }
    }

    #[test]
    fn exp() {
        let d1 = Dual::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = d1.exp();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.exp();
        let expected = Dual::try_new(
            c,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0 * c, 2.0 * c],
        )
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn log() {
        let d1 = Dual::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = d1.log();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.ln();
        let expected =
            Dual::try_new(c, vec!["v0".to_string(), "v1".to_string()], vec![1.0, 2.0]).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn norm_cdf() {
        let d1 = Dual::try_new(
            0.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = d1.norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        // Φ(0) = 0.5 and Φ'(0) = φ(0) = 1/√(2π).
        let pdf_at_zero = 1.0 / (2.0 * PI).sqrt();
        assert_close(&[result.real], &[0.5]);
        assert_close(&result.dual.to_vec(), &[pdf_at_zero, 2.0 * pdf_at_zero]);
    }

    #[test]
    fn inv_norm_cdf() {
        let d1 = Dual::try_new(
            0.5,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = d1.inv_norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        // Φ⁻¹(0.5) = 0 and (Φ⁻¹)'(0.5) = 1/φ(0) = √(2π).
        let slope = (2.0 * PI).sqrt();
        assert_close(&[result.real], &[0.0]);
        assert_close(&result.dual.to_vec(), &[slope, 2.0 * slope]);
    }

    #[test]
    fn norm_cdf_round_trip() {
        // Φ⁻¹(Φ(x)) = x, so the composition must reproduce the value and gradient.
        let d1 = Dual::try_new(
            0.3,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
        )
        .unwrap();
        let result = d1.norm_cdf().inv_norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        assert_close(&[result.real], &[0.3]);
        assert_close(&result.dual.to_vec(), &[1.0, 2.0]);
    }

    #[test]
    fn exp2() {
        let d1 = Dual2::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            Vec::new(),
        )
        .unwrap();
        let result = d1.exp();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.exp();
        let expected = Dual2::try_new(
            c,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0 * c, 2.0 * c],
            vec![
                1.0_f64.exp() * 0.5,
                1.0_f64.exp(),
                1.0_f64.exp(),
                1.0_f64.exp() * 2.,
            ],
        )
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn log2() {
        let d1 = Dual2::try_new(
            1.0,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            Vec::new(),
        )
        .unwrap();
        let result = d1.log();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.ln();
        let expected = Dual2::try_new(
            c,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            vec![-0.5, -1.0, -1.0, -2.0],
        )
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn norm_cdf_round_trip2() {
        // Φ⁻¹∘Φ is the identity, so first- and second-order parts pass through
        // unchanged. The input is built with an empty `dual2`, which (per the
        // `exp2`/`log2` expectations) means zero second-order terms, so the
        // output's second-order terms must also be ~0.
        let d1 = Dual2::try_new(
            0.3,
            vec!["v0".to_string(), "v1".to_string()],
            vec![1.0, 2.0],
            Vec::new(),
        )
        .unwrap();
        let result = d1.norm_cdf().inv_norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        assert_close(&[result.real], &[0.3]);
        assert_close(&result.dual.to_vec(), &[1.0, 2.0]);
        assert_close(
            &result.dual2.iter().copied().collect::<Vec<f64>>(),
            &[0.0; 4],
        );
    }

    #[test]
    fn number_dispatch() {
        match Number::F64(0.0).exp() {
            Number::F64(v) => assert_eq!(v, 1.0),
            _ => panic!("expected Number::F64"),
        }
        let make = || {
            Dual::try_new(
                2.0,
                vec!["v0".to_string(), "v1".to_string()],
                vec![1.0, 2.0],
            )
            .unwrap()
        };
        match Number::Dual(make()).log() {
            Number::Dual(d) => assert_eq!(d, make().log()),
            _ => panic!("expected Number::Dual"),
        }
    }
}