// rateslib/fx_volatility/sabr_funcs.rs

1use crate::dual::linalg::fouter11_;
2use crate::dual::{Dual, Dual2, MathFuncs, Number, Vars};
3
4use num_traits::{Pow, Signed};
5use pyo3::{pyfunction, PyErr};
6use std::sync::Arc;
7
8#[pyfunction]
9pub(crate) fn _sabr_x0(
10    k: Number,
11    f: Number,
12    _t: Number,
13    a: Number,
14    b: Number,
15    _p: Number,
16    _v: Number,
17    derivative: u8,
18) -> Result<(Number, Option<Number>), PyErr> {
19    // X0 = a / ((fk)^((1-b)/2) * (1 + (1-b)^2/24 ln^2(f/k) + (1-b)^4/1920 ln^4(f/k) )
20    //If ``derivative`` is 1 also returns dX0/dk, calculated using sympy.
21    //If ``derivative`` is 2 also returns dX0/df, calculated using sympy.
22    let x0 = 1_f64 / &k;
23    let x1 = 1_f64 / 24_f64 - &b / 24_f64;
24    let x2 = (&f * &x0).log();
25    let x3 = (1_f64 - &b).pow(4_f64);
26    let x4 = &x1 * (&x2).pow(2_f64) + (&x2).pow(4_f64) * &x3 / 1920_f64 + 1_f64;
27    let x5 = &b / 2_f64 - 0.5_f64;
28    let x6 = &a * (&f * &k).pow(&x5);
29
30    let x = &x6
31        / ((&x2).pow(4_f64) * (1_f64 - &b).pow(4_f64) / 1920_f64
32            + (&x2).pow(2_f64) * (1_f64 / 24_f64 - &b / 24_f64)
33            + 1_f64);
34
35    let dx: Option<Number> = match derivative {
36        1 => Some(
37            &x0 * x5 * &x6 / &x4
38                + &x6 * (2_f64 * &x0 * x1 * &x2 + &x0 * (&x2).pow(3_f64) * x3 / 480_f64)
39                    / x4.pow(2_f64),
40        ),
41        2 => {
42            let y0 = &b - 1_f64;
43            let y2 = (&y0).pow(2_f64) * (&x2).pow(2_f64);
44            let y3 = (&y0).pow(4_f64) * (&x2).pow(4_f64) + 80_f64 * &y2 + 1920_f64;
45            Some(
46                960_f64 * &a * &y0 * (&f * &k).pow(x5) * (-8_f64 * y0 * x2 * (&y2 + 40_f64) + &y3)
47                    / (f * y3.pow(2_f64)),
48            )
49        }
50        _ => None,
51    };
52
53    Ok((x, dx))
54}
55
56#[pyfunction]
57pub(crate) fn _sabr_x1(
58    k: Number,
59    f: Number,
60    t: Number,
61    a: Number,
62    b: Number,
63    p: Number,
64    v: Number,
65    derivative: u8,
66) -> Result<(Number, Option<Number>), PyErr> {
67    let x0 = 1_f64 / &k;
68    let x1 = &b / 2_f64 - 0.5_f64;
69    let x2 = &f * &k;
70    let x3 = &b - 1_f64;
71    let x = &t
72        * ((&a).pow(2_f64) * (&x2).pow(&x3) * (&x3).pow(2_f64) / 24_f64
73            + 0.25_f64 * &a * &b * &p * &v * (&x2).pow(&x1)
74            + (&v).pow(2_f64) * (2_f64 - 3_f64 * (&p).pow(2_f64)) / 24_f64)
75        + 1_f64;
76
77    let dx: Option<Number> = match derivative {
78        1 => Some(
79            &t * ((&a).pow(2_f64) * &x0 * (&x2).pow(&x3) * (&x3).pow(3_f64) / 24_f64
80                + 0.25 * &a * &b * p * &v * x0 * &x1 * x2.pow(x1)),
81        ),
82        2 => Some(
83            &a * &t
84                * &x3
85                * (&a * (&x3).pow(2_f64) * (&x2).pow(x3) + 3_f64 * b * p * v * x2.pow(x1))
86                / (24_f64 * f),
87        ),
88        _ => None,
89    };
90
91    Ok((x, dx))
92}
93
/// Compute the SABR `X2 = z / chi(z)` term and, optionally, one first derivative of it.
///
/// Here `z = v / a * (fk)^((1-b)/2) * ln(f/k)` and
/// `chi(z) = ln((sqrt(1 - 2pz + z^2) + z - p) / (1 - p))`.
///
/// `derivative` selects the second element of the returned tuple:
/// * `1` -> dX2/dk,
/// * `2` -> dX2/df,
/// * anything else -> `None`.
///
/// When |z| <= 1e-15 (e.g. f == k) the quotient is 0/0. In that case X2 is set to its limit
/// value 1, with the AD components given explicitly, and derivatives are formed by the chain
/// rule `dX2/dz * dz/d(k|f)` using `_sabr_dx2_dz`.
///
/// # Panics
/// The |z| <= 1e-15 path panics on some mixed `Dual` / `Dual2` combinations of `z` and `p`.
#[pyfunction]
pub(crate) fn _sabr_x2(
    k: Number,
    f: Number,
    _t: Number,
    a: Number,
    b: Number,
    p: Number,
    v: Number,
    derivative: u8,
) -> Result<(Number, Option<Number>), PyErr> {
    let x0 = 1_f64 / &k;
    let x1 = (&f * &x0).log(); // ln(f/k)
    let x2 = 1_f64 / &a;
    let x3 = &f * &k;
    let x4 = &b / 2_f64 - 0.5_f64; // (b-1)/2
    let x5 = (&x3).pow(-&x4); // (fk)^((1-b)/2)
    let x6 = &v * &x2 * &x5; // v/a * (fk)^((1-b)/2)

    let z = &x6 * &x1;
    let chi = (((1_f64 - 2_f64 * &p * &z + &z * &z).pow(0.5_f64) + &z - &p) / (1_f64 - &p)).log();

    let x: Number;
    if z.abs() > 1e-15_f64 {
        x = &z / &chi;
    } else {
        // handle the undefined quotient case when f=k by directly specifying dual numbers.
        // Near z = 0, X2(z, p) = 1 - p z / 2 + (2 - 3p^2) z^2 / 12 + ..., so the value is 1
        // and the f_* values below are the partial derivatives of X2 at z = 0.
        let p_f64 = f64::from(&p);

        x = match &z {
            Number::F64(_z) => Number::F64(1_f64),
            // First order: dX2/dz = -p/2; dX2/dp = 0 at z = 0, so p's gradient does not enter.
            Number::Dual(z_) => Number::Dual(Dual {
                real: 1_f64,
                dual: &z_.dual * p_f64 * -0.5_f64,
                vars: Arc::clone(&z_.vars),
            }),
            Number::Dual2(z_) => {
                // Put z and p on a common variable set so their gradients can be combined.
                let (z_cast, p_cast): (Dual2, Dual2) = match &p {
                    Number::F64(p_) => {
                        let temp = Dual2::new_from(z_, *p_, vec![]);
                        z_.to_union_vars(&temp, None)
                    }
                    Number::Dual(_) => panic!("Unexpected Dual/Dual2 type crossing in _sabr_x2"),
                    Number::Dual2(p_) => z_.to_union_vars(p_, None),
                };
                let f_z = -0.5_f64 * p_f64;
                // f_p = 0.0
                let f_zz = (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64;
                let f_zp = -0.5_f64;
                // f_pp = 0.0

                let mut dual2 = f_z * &z_cast.dual2.clone();
                dual2 =
                    dual2 + 0.5_f64 * f_zz * fouter11_(&z_cast.dual.view(), &z_cast.dual.view());
                // No p (x) p term is added because f_pp = 0.
                dual2 = dual2
                    + 0.5_f64
                        * f_zp
                        * (fouter11_(&z_cast.dual.view(), &p_cast.dual.view())
                            + fouter11_(&p_cast.dual.view(), &z_cast.dual.view()));
                Number::Dual2(Dual2 {
                    real: 1_f64,
                    vars: Arc::clone(&z_cast.vars),
                    dual: &z_cast.dual * p_f64 * -0.5_f64,
                    dual2,
                })
            }
        };
    }

    let dx: Option<Number>;
    match derivative {
        1 => {
            if z.abs() > 1e-15_f64 {
                // sympy-generated dX2/dk. In terms of the quantities above:
                // x7 = z, x14 * x9 = z^2, x15 = sqrt(1 - 2pz + z^2),
                // x16 = sqrt(1 - 2pz + z^2) + z - p, x17 = chi, x18 = 1/chi, x20 = (1-b)/2.
                let x7 = &x1 * &x6;
                let x8 = &p * &x7;
                let x9 = (&x1).pow(2_f64);
                let x10 = (&a).pow(-2_f64);
                let x11 = (&v).pow(2_f64);
                let x12 = &b - 1_f64;
                let x13 = (&x3).pow(-&x12);
                let x14 = &x10 * &x11 * &x13;
                let x15 = (&x14 * &x9 - 2_f64 * &x8 + 1_f64).pow(0.5_f64);
                let x16 = -&p + &x15 + &x7;
                let x17 = (&x16 / (1_f64 - &p)).log();
                let x18 = 1_f64 / &x17;
                let x19 = &x0 * &x6;
                let x20 = -&x4;
                // Same value as x0; a separate binding is used because `x0` is consumed by
                // value inside the expression below.
                let x21 = 1.0 * &x0;

                // dX2/dk = (dz/dk) / chi - z / chi^2 * dchi/dk, with dz/dk = x0 x20 z - x19.
                dx = Some(
                    &v * &x0 * &x1 * &x18 * &x2 * &x20 * &x5
                        - &x18 * &x19
                        - &x7
                            * (&x0 * &x20 * &x7 - x19
                                + (1.0_f64 * p * v * &x0 * x2 * x5
                                    - 0.5_f64 * x0 * x10 * x11 * x12 * x13 * x9
                                    - x1 * x14 * &x21
                                    - x20 * x21 * x8)
                                    / x15)
                            / (x16 * (&x17).pow(2_f64)),
                )
            } else {
                // Chain rule at z ~ 0: dX2/dk = dX2/dz * dz/dk.
                let dx_dz = _sabr_dx2_dz(&z, &p);

                let y0 = 1_f64 / &k;
                let y1 = &b / 2_f64 - 0.5_f64;
                let y2 = &v * &x0 / (&a * (&f * &k).pow(&y1));
                // dz/dk = -(v / (a k)) (fk)^((1-b)/2) * ((b-1)/2 * ln(f/k) + 1)
                let dz = -y2 * (y1 * (&f * y0).log() + 1_f64);

                dx = Some(dx_dz * dz);
            }
        }
        2 => {
            if z.abs() > 1e-15_f64 {
                // sympy-generated dX2/df. In terms of the quantities above:
                // y3 = (fk)^((1-b)/2), y7 = a z, y10 = sqrt(1 - 2pz + z^2), y13 = chi.
                let y0 = (&a).pow(2_f64);
                let y1 = 1_f64 / &y0;
                let y3 = (&x3).pow(-x4);
                let y4 = &a * &p;
                let y6 = &v * &x1;
                let y7 = &y3 * &y6;
                let y8 = &b - 1_f64;
                let y9 = (&x3).pow(-&y8);
                let y10 =
                    (&y1 * (&v * &v * &x1 * &x1 * &y9 + &y0 - 2_f64 * &y4 * &y7)).pow(0.5_f64);
                let y11 = &a * (-&p + &y10) + &y7;
                let y12 = &a * &y10;
                let y13 = ((&a * &p - &y12 - &y7) / (&a * (&p - 1_f64))).log();
                let y14 = &x1 * &y8 - 2_f64;
                let y15 = -&y14;

                dx = Some(
                    &v * &y1
                        * &y3
                        * (&y11 * &y12 * &y13 * &y15
                            + &y6 * (&y12 * &y14 * &y3 + &y14 * &y6 * &y9 + &y15 * &y3 * &y4))
                        / (2_f64 * &f * &y10 * &y11 * (&y13).pow(2_f64)),
                )
            } else {
                // Chain rule at z ~ 0: dX2/df = dX2/dz * dz/df, with
                // dz/df = v (fk)^((1-b)/2) * (2 - (b-1) ln(f/k)) / (2 a f).
                let dx_dz = _sabr_dx2_dz(&z, &p);
                let dz = &v * &x5 * (-(&b - 1_f64) * &x1 + 2_f64) / (2_f64 * &a * &f);
                dx = Some(dx_dz * dz);
            }
        }
        _ => dx = None,
    };

    Ok((x, dx))
}
243
244fn _sabr_dx2_dz(z: &Number, p: &Number) -> Number {
245    let p_f64 = f64::from(p);
246    match z {
247        Number::F64(_) => Number::F64(-p_f64 / 2_f64),
248        Number::Dual(z_) => {
249            let (z_cast, p_cast): (Dual, Dual) = match &p {
250                Number::F64(p_) => {
251                    let temp = Dual::new_from(z_, *p_, vec![]);
252                    z_.to_union_vars(&temp, None)
253                }
254                Number::Dual(p_) => z_.to_union_vars(p_, None),
255                Number::Dual2(_) => panic!("Unexpected Dual/Dual2 type crossing in _sabr_x2"),
256            };
257            let mut dual = -0.5_f64 * &p_cast.dual;
258            dual = dual + (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64 * &z_cast.dual;
259            Number::Dual(Dual {
260                real: -0.5_f64 * p_f64,
261                vars: Arc::clone(&z_cast.vars),
262                dual,
263            })
264        }
265        Number::Dual2(z_) => {
266            let (z_cast, p_cast): (Dual2, Dual2) = match &p {
267                Number::F64(p_) => {
268                    let temp = Dual2::new_from(z_, *p_, vec![]);
269                    z_.to_union_vars(&temp, None)
270                }
271                Number::Dual(_) => panic!("Unexpected Dual/Dual2 type crossing in _sabr_x2"),
272                Number::Dual2(p_) => z_.to_union_vars(p_, None),
273            };
274            let mut dual = -0.5_f64 * &p_cast.dual;
275            dual = dual + (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64 * &z_cast.dual;
276            let mut dual2 = (2_f64 - 3_f64 * p_f64 * p_f64) / 6_f64 * &z_cast.dual2;
277            dual2 = dual2 - 0.5_f64 * &p_cast.dual2;
278            dual2 = dual2
279                + p_f64 * (5_f64 - 6_f64 * p_f64 * p_f64) / 8_f64
280                    * fouter11_(&z_cast.dual.view(), &z_cast.dual.view());
281            dual2 = dual2
282                - 0.5_f64
283                    * p_f64
284                    * (fouter11_(&z_cast.dual.view(), &p_cast.dual.view())
285                        + fouter11_(&p_cast.dual.view(), &z_cast.dual.view()));
286            Number::Dual2(Dual2 {
287                real: -0.5_f64 * p_f64,
288                vars: Arc::clone(&z_cast.vars),
289                dual: dual,
290                dual2: dual2,
291            })
292        }
293    }
294}