1use crate::dual::linalg::fouter11_;
2use crate::dual::{Dual, Dual2, MathFuncs, Number, Vars};
3
4use num_traits::{Pow, Signed};
5use pyo3::{pyfunction, PyErr};
6use std::sync::Arc;
7
8#[pyfunction]
9pub(crate) fn _sabr_x0(
10 k: Number,
11 f: Number,
12 _t: Number,
13 a: Number,
14 b: Number,
15 _p: Number,
16 _v: Number,
17 derivative: u8,
18) -> Result<(Number, Option<Number>), PyErr> {
19 let x0 = 1_f64 / &k;
23 let x1 = 1_f64 / 24_f64 - &b / 24_f64;
24 let x2 = (&f * &x0).log();
25 let x3 = (1_f64 - &b).pow(4_f64);
26 let x4 = &x1 * (&x2).pow(2_f64) + (&x2).pow(4_f64) * &x3 / 1920_f64 + 1_f64;
27 let x5 = &b / 2_f64 - 0.5_f64;
28 let x6 = &a * (&f * &k).pow(&x5);
29
30 let x = &x6
31 / ((&x2).pow(4_f64) * (1_f64 - &b).pow(4_f64) / 1920_f64
32 + (&x2).pow(2_f64) * (1_f64 / 24_f64 - &b / 24_f64)
33 + 1_f64);
34
35 let dx: Option<Number> = match derivative {
36 1 => Some(
37 &x0 * x5 * &x6 / &x4
38 + &x6 * (2_f64 * &x0 * x1 * &x2 + &x0 * (&x2).pow(3_f64) * x3 / 480_f64)
39 / x4.pow(2_f64),
40 ),
41 2 => {
42 let y0 = &b - 1_f64;
43 let y2 = (&y0).pow(2_f64) * (&x2).pow(2_f64);
44 let y3 = (&y0).pow(4_f64) * (&x2).pow(4_f64) + 80_f64 * &y2 + 1920_f64;
45 Some(
46 960_f64 * &a * &y0 * (&f * &k).pow(x5) * (-8_f64 * y0 * x2 * (&y2 + 40_f64) + &y3)
47 / (f * y3.pow(2_f64)),
48 )
49 }
50 _ => None,
51 };
52
53 Ok((x, dx))
54}
55
56#[pyfunction]
57pub(crate) fn _sabr_x1(
58 k: Number,
59 f: Number,
60 t: Number,
61 a: Number,
62 b: Number,
63 p: Number,
64 v: Number,
65 derivative: u8,
66) -> Result<(Number, Option<Number>), PyErr> {
67 let x0 = 1_f64 / &k;
68 let x1 = &b / 2_f64 - 0.5_f64;
69 let x2 = &f * &k;
70 let x3 = &b - 1_f64;
71 let x = &t
72 * ((&a).pow(2_f64) * (&x2).pow(&x3) * (&x3).pow(2_f64) / 24_f64
73 + 0.25_f64 * &a * &b * &p * &v * (&x2).pow(&x1)
74 + (&v).pow(2_f64) * (2_f64 - 3_f64 * (&p).pow(2_f64)) / 24_f64)
75 + 1_f64;
76
77 let dx: Option<Number> = match derivative {
78 1 => Some(
79 &t * ((&a).pow(2_f64) * &x0 * (&x2).pow(&x3) * (&x3).pow(3_f64) / 24_f64
80 + 0.25 * &a * &b * p * &v * x0 * &x1 * x2.pow(x1)),
81 ),
82 2 => Some(
83 &a * &t
84 * &x3
85 * (&a * (&x3).pow(2_f64) * (&x2).pow(x3) + 3_f64 * b * p * v * x2.pow(x1))
86 / (24_f64 * f),
87 ),
88 _ => None,
89 };
90
91 Ok((x, dx))
92}
93
94#[pyfunction]
95pub(crate) fn _sabr_x2(
96 k: Number,
97 f: Number,
98 _t: Number,
99 a: Number,
100 b: Number,
101 p: Number,
102 v: Number,
103 derivative: u8,
104) -> Result<(Number, Option<Number>), PyErr> {
105 let x0 = 1_f64 / &k;
106 let x1 = (&f * &x0).log();
107 let x2 = 1_f64 / &a;
108 let x3 = &f * &k;
109 let x4 = &b / 2_f64 - 0.5_f64;
110 let x5 = (&x3).pow(-&x4);
111 let x6 = &v * &x2 * &x5;
112
113 let z = &x6 * &x1;
114 let chi = (((1_f64 - 2_f64 * &p * &z + &z * &z).pow(0.5_f64) + &z - &p) / (1_f64 - &p)).log();
115
116 let x: Number;
117 if z.abs() > 1e-15_f64 {
118 x = &z / χ
119 } else {
120 let p_f64 = f64::from(&p);
122
123 x = match &z {
124 Number::F64(_z) => Number::F64(1_f64),
125 Number::Dual(z_) => Number::Dual(Dual {
126 real: 1_f64,
127 dual: &z_.dual * p_f64 * -0.5_f64,
128 vars: Arc::clone(&z_.vars),
129 }),
130 Number::Dual2(z_) => {
131 let (z_cast, p_cast): (Dual2, Dual2) = match &p {
132 Number::F64(p_) => {
133 let temp = Dual2::new_from(z_, *p_, vec![]);
134 z_.to_union_vars(&temp, None)
135 }
136 Number::Dual(_) => panic!("Unexpected Dual/Dual2 type crossing in _sabr_x2"),
137 Number::Dual2(p_) => z_.to_union_vars(p_, None),
138 };
139 let f_z = -0.5_f64 * p_f64;
140 let f_zz = (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64;
142 let f_zp = -0.5_f64;
143 let mut dual2 = f_z * &z_cast.dual2.clone();
146 dual2 =
147 dual2 + 0.5_f64 * f_zz * fouter11_(&z_cast.dual.view(), &z_cast.dual.view());
148 dual2 = dual2
150 + 0.5_f64
151 * f_zp
152 * (fouter11_(&z_cast.dual.view(), &p_cast.dual.view())
153 + fouter11_(&p_cast.dual.view(), &z_cast.dual.view()));
154 Number::Dual2(Dual2 {
155 real: 1_f64,
156 vars: Arc::clone(&z_cast.vars),
157 dual: &z_cast.dual * p_f64 * -0.5_f64,
158 dual2,
159 })
160 }
161 };
162 }
163
164 let dx: Option<Number>;
165 match derivative {
166 1 => {
167 if z.abs() > 1e-15_f64 {
168 let x7 = &x1 * &x6;
169 let x8 = &p * &x7;
170 let x9 = (&x1).pow(2_f64);
171 let x10 = (&a).pow(-2_f64);
172 let x11 = (&v).pow(2_f64);
173 let x12 = &b - 1_f64;
174 let x13 = (&x3).pow(-&x12);
175 let x14 = &x10 * &x11 * &x13;
176 let x15 = (&x14 * &x9 - 2_f64 * &x8 + 1_f64).pow(0.5_f64);
177 let x16 = -&p + &x15 + &x7;
178 let x17 = (&x16 / (1_f64 - &p)).log();
179 let x18 = 1_f64 / &x17;
180 let x19 = &x0 * &x6;
181 let x20 = -&x4;
182 let x21 = 1.0 * &x0;
183
184 dx = Some(
185 &v * &x0 * &x1 * &x18 * &x2 * &x20 * &x5
186 - &x18 * &x19
187 - &x7
188 * (&x0 * &x20 * &x7 - x19
189 + (1.0_f64 * p * v * &x0 * x2 * x5
190 - 0.5_f64 * x0 * x10 * x11 * x12 * x13 * x9
191 - x1 * x14 * &x21
192 - x20 * x21 * x8)
193 / x15)
194 / (x16 * (&x17).pow(2_f64)),
195 )
196 } else {
197 let dx_dz = _sabr_dx2_dz(&z, &p);
198
199 let y0 = 1_f64 / &k;
200 let y1 = &b / 2_f64 - 0.5_f64;
201 let y2 = &v * &x0 / (&a * (&f * &k).pow(&y1));
202 let dz = -y2 * (y1 * (&f * y0).log() + 1_f64);
203
204 dx = Some(dx_dz * dz);
205 }
206 }
207 2 => {
208 if z.abs() > 1e-15_f64 {
209 let y0 = (&a).pow(2_f64);
210 let y1 = 1_f64 / &y0;
211 let y3 = (&x3).pow(-x4);
212 let y4 = &a * &p;
213 let y6 = &v * &x1;
214 let y7 = &y3 * &y6;
215 let y8 = &b - 1_f64;
216 let y9 = (&x3).pow(-&y8);
217 let y10 =
218 (&y1 * (&v * &v * &x1 * &x1 * &y9 + &y0 - 2_f64 * &y4 * &y7)).pow(0.5_f64);
219 let y11 = &a * (-&p + &y10) + &y7;
220 let y12 = &a * &y10;
221 let y13 = ((&a * &p - &y12 - &y7) / (&a * (&p - 1_f64))).log();
222 let y14 = &x1 * &y8 - 2_f64;
223 let y15 = -&y14;
224
225 dx = Some(
226 &v * &y1
227 * &y3
228 * (&y11 * &y12 * &y13 * &y15
229 + &y6 * (&y12 * &y14 * &y3 + &y14 * &y6 * &y9 + &y15 * &y3 * &y4))
230 / (2_f64 * &f * &y10 * &y11 * (&y13).pow(2_f64)),
231 )
232 } else {
233 let dx_dz = _sabr_dx2_dz(&z, &p);
234 let dz = &v * &x5 * (-(&b - 1_f64) * &x1 + 2_f64) / (2_f64 * &a * &f);
235 dx = Some(dx_dz * dz);
236 }
237 }
238 _ => dx = None,
239 };
240
241 Ok((x, dx))
242}
243
244fn _sabr_dx2_dz(z: &Number, p: &Number) -> Number {
245 let p_f64 = f64::from(p);
246 match z {
247 Number::F64(_) => Number::F64(-p_f64 / 2_f64),
248 Number::Dual(z_) => {
249 let (z_cast, p_cast): (Dual, Dual) = match &p {
250 Number::F64(p_) => {
251 let temp = Dual::new_from(z_, *p_, vec![]);
252 z_.to_union_vars(&temp, None)
253 }
254 Number::Dual(p_) => z_.to_union_vars(p_, None),
255 Number::Dual2(_) => panic!("Unexpected Dual/Dual2 type crossing in _sabr_x2"),
256 };
257 let mut dual = -0.5_f64 * &p_cast.dual;
258 dual = dual + (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64 * &z_cast.dual;
259 Number::Dual(Dual {
260 real: -0.5_f64 * p_f64,
261 vars: Arc::clone(&z_cast.vars),
262 dual,
263 })
264 }
265 Number::Dual2(z_) => {
266 let (z_cast, p_cast): (Dual2, Dual2) = match &p {
267 Number::F64(p_) => {
268 let temp = Dual2::new_from(z_, *p_, vec![]);
269 z_.to_union_vars(&temp, None)
270 }
271 Number::Dual(_) => panic!("Unexpected Dual/Dual2 type crossing in _sabr_x2"),
272 Number::Dual2(p_) => z_.to_union_vars(p_, None),
273 };
274 let mut dual = -0.5_f64 * &p_cast.dual;
275 dual = dual + (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64 * &z_cast.dual;
276 let mut dual2 = (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64 * &z_cast.dual2;
277 dual2 = dual2 - 0.5_f64 * &p_cast.dual2;
278 dual2 = dual2
279 + p_f64 * (5_f64 - 6_f64 * p_f64 * p_f64) / 8_f64
280 * fouter11_(&z_cast.dual.view(), &z_cast.dual.view());
281 dual2 = dual2
282 - 0.5_f64
283 * p_f64
284 * (fouter11_(&z_cast.dual.view(), &p_cast.dual.view())
285 + fouter11_(&p_cast.dual.view(), &z_cast.dual.view()));
286 Number::Dual2(Dual2 {
287 real: -0.5_f64 * p_f64,
288 vars: Arc::clone(&z_cast.vars),
289 dual: dual,
290 dual2: dual2,
291 })
292 }
293 }
294}