1use crate::dual::dual::{Dual, Dual2};
2use crate::dual::enums::Number;
3use crate::dual::linalg::fouter11_;
4use num_traits::Pow;
5use statrs::distribution::{ContinuousCDF, Normal};
6use std::f64::consts::PI;
7use std::sync::Arc;
8
/// Elementary functions that can be applied to plain floats and to the dual-number
/// types of this crate.
///
/// For [`Dual`] the implementations propagate the first-order gradient by the chain
/// rule; for [`Dual2`] they additionally propagate the second-order part. This module
/// implements the trait for `f64`, [`Dual`], [`Dual2`] and [`Number`].
pub trait MathFuncs {
    /// Natural exponential, `e^x`.
    fn exp(&self) -> Self;

    /// Natural logarithm, `ln(x)`.
    fn log(&self) -> Self;

    /// Standard normal cumulative distribution function, `Φ(x)`.
    fn norm_cdf(&self) -> Self;

    /// Inverse of the standard normal CDF (the quantile function), `Φ⁻¹(p)`.
    fn inv_norm_cdf(&self) -> Self;
}
20
21impl MathFuncs for Dual {
22 fn exp(&self) -> Self {
23 let c = self.real.exp();
24 Dual {
25 real: c,
26 vars: Arc::clone(&self.vars),
27 dual: c * &self.dual,
28 }
29 }
30 fn log(&self) -> Self {
31 Dual {
32 real: self.real.ln(),
33 vars: Arc::clone(&self.vars),
34 dual: (1.0 / self.real) * &self.dual,
35 }
36 }
37 fn norm_cdf(&self) -> Self {
38 let n = Normal::new(0.0, 1.0).unwrap();
39 let base = n.cdf(self.real);
40 let scalar = 1.0 / (2.0 * PI).sqrt() * (-0.5_f64 * self.real.pow(2.0_f64)).exp();
41 Dual {
42 real: base,
43 vars: Arc::clone(&self.vars),
44 dual: scalar * &self.dual,
45 }
46 }
47 fn inv_norm_cdf(&self) -> Self {
48 let n = Normal::new(0.0, 1.0).unwrap();
49 let base = n.inverse_cdf(self.real);
50 let scalar = (2.0 * PI).sqrt() * (0.5_f64 * base.pow(2.0_f64)).exp();
51 Dual {
52 real: base,
53 vars: Arc::clone(&self.vars),
54 dual: scalar * &self.dual,
55 }
56 }
57}
58
59impl MathFuncs for Dual2 {
60 fn exp(&self) -> Self {
61 let c = self.real.exp();
62 Dual2 {
63 real: c,
64 vars: Arc::clone(&self.vars),
65 dual: c * &self.dual,
66 dual2: c * (&self.dual2 + 0.5 * fouter11_(&self.dual.view(), &self.dual.view())),
67 }
68 }
69 fn log(&self) -> Self {
70 let scalar = 1.0 / self.real;
71 Dual2 {
72 real: self.real.ln(),
73 vars: Arc::clone(&self.vars),
74 dual: scalar * &self.dual,
75 dual2: scalar * &self.dual2
76 - fouter11_(&self.dual.view(), &self.dual.view()) * 0.5 * (scalar * scalar),
77 }
78 }
79 fn norm_cdf(&self) -> Self {
80 let n = Normal::new(0.0, 1.0).unwrap();
81 let base = n.cdf(self.real);
82 let scalar = 1.0 / (2.0 * PI).sqrt() * (-0.5_f64 * self.real.pow(2.0_f64)).exp();
83 let scalar2 = scalar * -self.real;
84 let cross_beta = fouter11_(&self.dual.view(), &self.dual.view());
85 Dual2 {
86 real: base,
87 vars: Arc::clone(&self.vars),
88 dual: scalar * &self.dual,
89 dual2: scalar * &self.dual2 + 0.5_f64 * scalar2 * cross_beta,
90 }
91 }
92 fn inv_norm_cdf(&self) -> Self {
93 let n = Normal::new(0.0, 1.0).unwrap();
94 let base = n.inverse_cdf(self.real);
95 let scalar = (2.0 * PI).sqrt() * (0.5_f64 * base.pow(2.0_f64)).exp();
96 let scalar2 = scalar.pow(2.0_f64) * base;
97 let cross_beta = fouter11_(&self.dual.view(), &self.dual.view());
98 Dual2 {
99 real: base,
100 vars: Arc::clone(&self.vars),
101 dual: scalar * &self.dual,
102 dual2: scalar * &self.dual2 + 0.5_f64 * scalar2 * cross_beta,
103 }
104 }
105}
106
107impl MathFuncs for f64 {
108 fn inv_norm_cdf(&self) -> Self {
109 Normal::new(0.0, 1.0).unwrap().inverse_cdf(*self)
110 }
111 fn norm_cdf(&self) -> Self {
112 Normal::new(0.0, 1.0).unwrap().cdf(*self)
113 }
114 fn exp(&self) -> Self {
115 f64::exp(*self)
116 }
117 fn log(&self) -> Self {
118 f64::ln(*self)
119 }
120}
121
// Applies the unary method `$method` to whichever concrete value a `Number` holds,
// and re-wraps the result in the same variant. `$value` is matched directly, so
// passing `self: &Number` binds each payload by reference.
macro_rules! math_func {
    ($value:ident, $method:ident) => {
        match $value {
            Number::F64(inner) => Number::F64(inner.$method()),
            Number::Dual(inner) => Number::Dual(inner.$method()),
            Number::Dual2(inner) => Number::Dual2(inner.$method()),
        }
    };
}
131
132impl MathFuncs for Number {
133 fn inv_norm_cdf(&self) -> Self {
134 math_func!(self, inv_norm_cdf)
135 }
136 fn norm_cdf(&self) -> Self {
137 math_func!(self, norm_cdf)
138 }
139 fn exp(&self) -> Self {
140 math_func!(self, exp)
141 }
142 fn log(&self) -> Self {
143 math_func!(self, log)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within an absolute tolerance suitable for
    /// closed-form derivative checks.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    fn two_vars() -> Vec<String> {
        vec!["v0".to_string(), "v1".to_string()]
    }

    #[test]
    fn exp() {
        let d1 = Dual::try_new(1.0, two_vars(), vec![1.0, 2.0]).unwrap();
        let result = d1.exp();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.exp();
        let expected = Dual::try_new(c, two_vars(), vec![1.0 * c, 2.0 * c]).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn log() {
        let d1 = Dual::try_new(1.0, two_vars(), vec![1.0, 2.0]).unwrap();
        let result = d1.log();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.ln();
        let expected = Dual::try_new(c, two_vars(), vec![1.0, 2.0]).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn log_scales_gradient_by_reciprocal() {
        // At x = 1 the factor 1/x is invisible; x = 2 checks the gradient is halved.
        let d1 = Dual::try_new(2.0, two_vars(), vec![1.0, 2.0]).unwrap();
        let result = d1.log();
        assert_close(result.real, 2.0_f64.ln());
        assert_close(result.dual[0], 0.5);
        assert_close(result.dual[1], 1.0);
    }

    #[test]
    fn norm_cdf() {
        let d1 = Dual::try_new(0.0, two_vars(), vec![1.0, 2.0]).unwrap();
        let result = d1.norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        // Φ(0) = 0.5 and Φ'(0) = φ(0) = 1/√(2π).
        let pdf0 = 1.0 / (2.0 * PI).sqrt();
        assert_close(result.real, 0.5);
        assert_close(result.dual[0], pdf0);
        assert_close(result.dual[1], 2.0 * pdf0);
    }

    #[test]
    fn inv_norm_cdf() {
        let d1 = Dual::try_new(0.5, two_vars(), vec![1.0, 2.0]).unwrap();
        let result = d1.inv_norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        // Φ⁻¹(0.5) = 0 and (Φ⁻¹)'(0.5) = 1/φ(0) = √(2π).
        let slope = (2.0 * PI).sqrt();
        assert_close(result.real, 0.0);
        assert_close(result.dual[0], slope);
        assert_close(result.dual[1], 2.0 * slope);
    }

    #[test]
    fn exp2() {
        let d1 = Dual2::try_new(1.0, two_vars(), vec![1.0, 2.0], Vec::new()).unwrap();
        let result = d1.exp();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.exp();
        let expected = Dual2::try_new(
            c,
            two_vars(),
            vec![1.0 * c, 2.0 * c],
            vec![
                1.0_f64.exp() * 0.5,
                1.0_f64.exp(),
                1.0_f64.exp(),
                1.0_f64.exp() * 2.,
            ],
        )
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn log2() {
        let d1 = Dual2::try_new(1.0, two_vars(), vec![1.0, 2.0], Vec::new()).unwrap();
        let result = d1.log();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        let c = 1.0_f64.ln();
        let expected = Dual2::try_new(
            c,
            two_vars(),
            vec![1.0, 2.0],
            vec![-0.5, -1.0, -1.0, -2.0],
        )
        .unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn norm_cdf2() {
        let d1 = Dual2::try_new(1.0, two_vars(), vec![1.0, 2.0], Vec::new()).unwrap();
        let result = d1.norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        // Φ'(1) = φ(1) and Φ''(1) = -φ(1); the second-order term carries a factor
        // of 0.5, as in the other Dual2 rules, and multiplies dual ⊗ dual.
        let pdf1 = (-0.5_f64).exp() / (2.0 * PI).sqrt();
        assert_close(result.real, 0.8413447460685429);
        assert_close(result.dual[0], pdf1);
        assert_close(result.dual[1], 2.0 * pdf1);
        let expected = [-0.5 * pdf1, -pdf1, -pdf1, -2.0 * pdf1];
        assert_eq!(result.dual2.len(), expected.len());
        for (actual, want) in result.dual2.iter().zip(expected.iter()) {
            assert_close(*actual, *want);
        }
    }

    #[test]
    fn norm_cdf_inv_norm_cdf_round_trip2() {
        // Φ⁻¹(Φ(x)) = x, so composing the two must return the input value and
        // gradient unchanged and leave a zero second-order part at zero. This
        // exercises both second-order rules against each other.
        let d1 = Dual2::try_new(0.5, two_vars(), vec![1.0, 2.0], Vec::new()).unwrap();
        let result = d1.norm_cdf().inv_norm_cdf();
        assert!(Arc::ptr_eq(&d1.vars, &result.vars));
        assert_close(result.real, 0.5);
        for (actual, want) in result.dual.iter().zip(d1.dual.iter()) {
            assert_close(*actual, *want);
        }
        for actual in result.dual2.iter() {
            assert_close(*actual, 0.0);
        }
    }

    #[test]
    fn number_dispatch_keeps_variant() {
        match MathFuncs::exp(&Number::F64(1.0)) {
            Number::F64(v) => assert_close(v, 1.0_f64.exp()),
            _ => panic!("exp of Number::F64 should stay Number::F64"),
        }
    }
}