rateslib/dual/linalg/
linalg_dual.rs

1//! Perform linear algebra operations on arrays containing generic data types.
2
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Zip;
6use num_traits::identities::Zero;
7use num_traits::Signed;
8use std::cmp::PartialOrd;
9use std::iter::Sum;
10use std::ops::{Div, Mul, Sub};
11
12// Tensor ops
13
14/// Outer product of two 1d-arrays containing generic objects.
15pub fn douter11_<T>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> Array2<T>
16where
17    for<'a> &'a T: Mul<&'a T, Output = T>,
18    T: Sum,
19{
20    Array1::from_vec(
21        a.iter()
22            .cartesian_product(b.iter())
23            .map(|(x, y)| x * y)
24            .collect(),
25    )
26    .into_shape_with_order((a.len(), b.len()))
27    .expect("Pre checked dimensions")
28}
29
30/// Inner product between two 1d-arrays.
31pub fn dmul11_<T>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> T
32where
33    for<'a> &'a T: Mul<&'a T, Output = T>,
34    T: Sum,
35{
36    assert_eq!(a.len(), b.len());
37    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
38}
39
40/// Matrix multiplication between a 2d-array and a 1d-array.
41pub fn dmul21_<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
42where
43    for<'a> &'a T: Mul<&'a T, Output = T>,
44    T: Sum,
45{
46    assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
47    Array1::from_vec(a.axis_iter(Axis(0)).map(|row| dmul11_(&row, b)).collect())
48}
49
50/// Matrix multiplication between two 2d-arrays.
51pub fn dmul22_<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> Array2<T>
52where
53    for<'a> &'a T: Mul<&'a T, Output = T>,
54    T: Sum,
55{
56    assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
57    Array1::<T>::from_vec(
58        a.axis_iter(Axis(0))
59            .cartesian_product(b.axis_iter(Axis(1)))
60            .map(|(row, col)| dmul11_(&row, &col))
61            .collect(),
62    )
63    .into_shape_with_order((a.len_of(Axis(0)), b.len_of(Axis(1))))
64    .expect("Dim are pre-checked")
65}
66
67// Linalg solver
68
69pub(crate) fn argabsmax<T>(a: ArrayView1<T>) -> usize
70where
71    T: Signed + PartialOrd,
72{
73    let vi: (&T, usize) = a
74        .iter()
75        .zip(0..)
76        .max_by(|x, y| x.0.abs().partial_cmp(&y.0.abs()).unwrap())
77        .unwrap();
78    vi.1
79}
80
81// pub(crate) fn argabsmax2<T>(a: ArrayView2<T>) -> (usize, usize)
82// where
83//     T: Signed + PartialOrd,
84// {
85//     let vi: (&T, usize) = a
86//         .iter()
87//         .zip(0..)
88//         .max_by(|x, y| x.0.abs().partial_cmp(&y.0.abs()).unwrap())
89//         .unwrap();
90//     let n = a.len_of(Axis(0));
91//     (vi.1 / n, vi.1 % n)
92// }
93
94pub(crate) fn row_swap<T>(p: &mut Array2<T>, j: &usize, kr: &usize) {
95    let (mut pt, mut pb) = p.slice_mut(s![.., ..]).split_at(Axis(0), *kr);
96    let (r1, r2) = (pt.row_mut(*j), pb.row_mut(0));
97    Zip::from(r1).and(r2).for_each(std::mem::swap);
98}
99
100// pub(crate) fn col_swap<T>(p: &mut Array2<T>, j: &usize, kc: &usize)
101// {
102//     let (mut pl, mut pr) = p.slice_mut(s![.., ..]).split_at(Axis(1), *kc);
103//     let (c1, c2) = (pl.column_mut(*j), pr.column_mut(0));
104//     Zip::from(c1).and(c2).for_each(std::mem::swap);
105// }
106
107pub(crate) fn el_swap<T>(p: &mut Array1<T>, j: &usize, k: &usize) {
108    let (mut pl, mut pr) = p.slice_mut(s![..]).split_at(Axis(0), *k);
109    std::mem::swap(&mut pl[*j], &mut pr[0]);
110}
111
112// fn partial_pivot_matrix<T>(a: &ArrayView2<T>) -> (Array2<f64>, Array2<f64>, Array2<T>)
113// where
114//     T: Signed + Num + PartialOrd + Clone,
115// {
116//     // pivot square matrix
117//     let n = a.len_of(Axis(0));
118//     let mut p: Array2<f64> = Array::eye(n);
119//     let q: Array2<f64> = Array::eye(n);
120//     let mut pa = a.to_owned();
121//     for j in 0..n {
122//         let k = argabsmax(pa.slice(s![j.., j])) + j;
123//         if j != k {
124//             // define row swaps j <-> k  (note that k > j by definition)
125//             let (mut pt, mut pb) = p.slice_mut(s![.., ..]).split_at(Axis(0), k);
126//             let (r1, r2) = (pt.row_mut(j), pb.row_mut(0));
127//             Zip::from(r1).and(r2).for_each(std::mem::swap);
128//
129//             let (mut pt, mut pb) = pa.slice_mut(s![.., ..]).split_at(Axis(0), k);
130//             let (r1, r2) = (pt.row_mut(j), pb.row_mut(0));
131//             Zip::from(r1).and(r2).for_each(std::mem::swap);
132//         }
133//     }
134//     (p, q, pa)
135// }
136//
137// fn complete_pivot_matrix<T>(a: &ArrayView2<T>) -> (Array2<f64>, Array2<f64>, Array2<T>)
138// where
139//     T: Signed + Num + PartialOrd + Clone,
140// {
141//     // pivot square matrix
142//     let n = a.len_of(Axis(0));
143//     let mut p: Array2<f64> = Array::eye(n);
144//     let mut q: Array2<f64> = Array::eye(n);
145//     let mut at = a.to_owned();
146//
147//     for j in 0..n {
148//         // iterate diagonally through
149//         let (mut kr, mut kc) = argabsmax2(at.slice(s![j.., j..]));
150//         kr += j;
//         kc += j; // align with outer scope array indices
152//
153//         match (kr, kc) {
154//             (kr, kc) if kr > j && kc > j => {
155//                 row_swap(&mut p, &j, &kr);
156//                 row_swap(&mut at, &j, &kr);
157//                 col_swap(&mut q, &j, &kc);
158//                 col_swap(&mut at, &j, &kc);
159//             }
160//             (kr, kc) if kr > j && kc == j => {
161//                 row_swap(&mut p, &j, &kr);
162//                 row_swap(&mut at, &j, &kr);
163//             }
164//             (kr, kc) if kr == j && kc > j => {
165//                 col_swap(&mut q, &j, &kc);
166//                 col_swap(&mut at, &j, &kc);
167//             }
168//             _ => {}
169//         }
170//     }
171//     (p, q, at)
172// }
173//
174// fn rook_pivot_matrix<T>(a: &ArrayView2<T>) -> (Array2<f64>, Array2<f64>, Array2<T>)
175// where
176//     T: Signed + Num + PartialOrd + Clone,
177// {
178//     // Implement a modified Rook Pivot.
179//     // If Original is the largest Abs in the row, and it is greater than some
180//     // tolerance then use that. This prevents row swapping where the rightmost columns
181//     // are zero, which ultimately leads to failure in sparse matrices.
182//
183//     // pivot square matrix
184//     let n = a.len_of(Axis(0));
185//     let mut p: Array2<f64> = Array::eye(n);
186//     let mut q: Array2<f64> = Array::eye(n);
187//     let mut at = a.to_owned();
188//
189//     for j in 0..n {
190//         // iterate diagonally through
191//         let kr = argabsmax(at.slice(s![j.., j])) + j;
192//         let kc = argabsmax(at.slice(s![j, j..])) + j;
193//
194//         match (kr, kc) {
195//             (kr, kc) if kr > j && kc > j => {
196//                 if at[[kr, j]].abs() > at[[j, kc]].abs() {
197//                     row_swap(&mut p, &j, &kr);
198//                     row_swap(&mut at, &j, &kr);
199//                 } else {
200//                     col_swap(&mut q, &j, &kc);
201//                     col_swap(&mut at, &j, &kc);
202//                 }
203//             }
204//             (kr, kc) if kr > j && kc == j => {
205//                 // MODIFIER as explained:
206//                 // if !(at[[j, j]].abs() > 1e-8) {
207//                     row_swap(&mut p, &j, &kr);
208//                     row_swap(&mut at, &j, &kr);
209//                 // }
210//             }
211//             (kr, kc) if kr == j && kc > j => {
212//                 col_swap(&mut q, &j, &kc);
213//                 col_swap(&mut at, &j, &kc);
214//             }
215//             _ => {}
216//         }
217//     }
218//     (p, q, at)
219// }
220//
221// pub enum PivotMethod {
222//     Partial,
223//     Complete,
224//     Rook,
225// }
226
227// pub fn pluq_decomp<T>(
228//     a: &ArrayView2<T>,
229//     pivot: PivotMethod,
230// ) -> (Array2<f64>, Array2<T>, Array2<T>, Array2<f64>)
231// where
232//     T: Signed + Num + PartialOrd + Clone + One + Zero + Sum + for<'a> Div<&'a T, Output = T>,
233//     for<'a> &'a T: Mul<&'a T, Output = T> + Sub<T, Output = T>,
234// {
235//     let n: usize = a.len_of(Axis(0));
236//     let mut l: Array2<T> = Array2::zeros((n, n));
237//     let mut u: Array2<T> = Array2::zeros((n, n));
238//     let p;
239//     let q;
240//     let paq;
241//     match pivot {
242//         PivotMethod::Partial => (p, q, paq) = partial_pivot_matrix(a),
243//         PivotMethod::Complete => (p, q, paq) = complete_pivot_matrix(a),
244//         PivotMethod::Rook => {
245//             (p, q, paq) = rook_pivot_matrix(a);
246//         }
247//     }
248//
249//     let one = T::one();
250//     for j in 0..n {
251//         l[[j, j]] = one.clone(); // all diagonal entries of L are set to unity
252//
253//         for i in 0..j + 1 {
254//             // LaTeX: u_{ij} = a_{ij} - \sum_{k=1}^{i-1} u_{kj} l_{ik}
255//             let sx = dmul11_(&l.slice(s![i, ..i]), &u.slice(s![..i, j]));
256//             u[[i, j]] = &paq[[i, j]] - sx;
257//         }
258//
259//         for i in j..n {
260//             // LaTeX: l_{ij} = \frac{1}{u_{jj}} (a_{ij} - \sum_{k=1}^{j-1} u_{kj} l_{ik})
261//             let sy = dmul11_(&l.slice(s![i, ..j]), &u.slice(s![..j, j]));
262//             l[[i, j]] = (&paq[[i, j]] - sy) / &u[[j, j]];
263//         }
264//     }
265//     (p, l, u, q)
266// }
267
268// fn dsolve_lower21_<T>(l: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
269// where
270//     T: Clone + Sum + Zero,
271//     for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>
272// {
273//     let n: usize = l.len_of(Axis(0));
274//     let mut x: Array1<T> = Array::zeros(n);
275//     for i in 0..n {
276//         let v = &b[i] - &dmul11_(&l.slice(s![i, ..i]), &x.slice(s![..i]));
277//         x[i] = &v / &l[[i, i]]
278//     }
279//     x
280// }
281
282fn dsolve_upper21_<T>(u: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
283where
284    T: Clone + Sum + Zero,
285    for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>,
286{
287    let n: usize = u.len_of(Axis(0));
288    let mut x: Array1<T> = Array::zeros(n);
289    for i in (0..n).rev() {
290        let v = &b[i] - &dmul11_(&u.slice(s![i, (i + 1)..]), &x.slice(s![(i + 1)..]));
291        x[i] = &v / &u[[i, i]]
292    }
293    x
294}
295
296fn dsolve21_<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
297where
298    T: PartialOrd + Signed + Clone + Zero + Sum,
299    for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>,
300{
301    assert!(a.is_square());
302    let n = a.len_of(Axis(0));
303    assert_eq!(b.len_of(Axis(0)), n);
304
305    // a_ and b_ will be pivoted and amended throughout the solution
306    let mut a_ = a.to_owned();
307    let mut b_ = b.to_owned();
308
309    for j in 0..n {
310        let k = argabsmax(a_.slice(s![j.., j])) + j;
311        if j != k {
312            // define row swaps j <-> k  (note that k > j by definition)
313            row_swap(&mut a_, &j, &k);
314            el_swap(&mut b_, &j, &k);
315        }
316        // perform reduction on subsequent rows below j
317        for l in (j + 1)..n {
318            let scl = &a_[[l, j]] / &a_[[j, j]];
319            a_[[l, j]] = T::zero();
320            for m in (j + 1)..n {
321                a_[[l, m]] = &a_[[l, m]] - &(&scl * &a_[[j, m]]);
322            }
323            b_[l] = &b_[l] - &(&scl * &b_[j]);
324        }
325    }
326    dsolve_upper21_(&a_.view(), &b_.view())
327}
328
329// fn dsolve_upper_1d<T>(u: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
330// where
331//     T: Clone + Sum + Zero + for<'a> Div<&'a T, Output = T>,
332//     for<'a> &'a T: Sub<T, Output = T> + Mul<&'a T, Output = T>,
333// {
334//     // reverse all dimensions and solve as lower triangular
335//     dsolve_lower_1d(&u.slice(s![..;-1, ..;-1]), &b.slice(s![..;-1]))
336//         .slice(s![..;-1])
337//         .to_owned()
338// }
339
340// fn dsolve21_<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
341// where
342//     T: PartialOrd + Signed + Clone + Sum + Zero + for<'a> Div<&'a T, Output = T>,
343//     for<'a> &'a T: Mul<&'a f64, Output = T> + Sub<T, Output = T> + Mul<&'a T, Output = T>,
344//     for<'a> &'a f64: Mul<&'a T, Output = T>,
345// {
346//     let (p, l, u, q) = pluq_decomp::<T>(&a.view(), PivotMethod::Complete);
347//     let pb: Array1<T> = fdmul21_(&p.view(), &b.view());
348//     let z: Array1<T> = dsolve_lower_1d(&l.view(), &pb.view());
349//     let y: Array1<T> = dsolve_upper_1d(&u.view(), &z.view());
350//     let x: Array1<T> = fdmul21_(&q.view(), &y.view());
351//     x
352// }
353
354/// Solve a linear system of equations, ax = b, using Gaussian elimination and partial pivoting.
355///
356/// - `a` is a 2d-array.
357/// - `b` is a 1d-array.
358/// - `allow_lsq` can be set to `true` if the number of rows in `a` is greater than its number of columns.
359pub fn dsolve<T>(a: &ArrayView2<T>, b: &ArrayView1<T>, allow_lsq: bool) -> Array1<T>
360where
361    T: PartialOrd + Signed + Clone + Sum + Zero,
362    for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>,
363{
364    if allow_lsq {
365        let a_ = dmul22_(&a.t(), a);
366        let b_ = dmul21_(&a.t(), b);
367        dsolve21_(&a_.view(), &b_.view())
368    } else {
369        dsolve21_(a, b)
370    }
371}
372
373// UNIT TESTS
374
375//
376
377//
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dual::dual::{Dual, Vars};
    use std::sync::Arc;

    /// Build a `Dual` with the given real value and no variables.
    fn constant(real: f64) -> Dual {
        Dual::new(real, Vec::new())
    }

    /// Build a `Dual` with the given real value and the named variables.
    fn with_vars(real: f64, names: &[&str]) -> Dual {
        Dual::new(real, names.iter().map(|name| name.to_string()).collect())
    }

    // fn is_close(a: &f64, b: &f64, abs_tol: Option<f64>) -> bool {
    //     // used rather than equality for float numbers
    //     (a - b).abs() < abs_tol.unwrap_or(1e-8)
    // }

    #[test]
    fn argabsmx_i32() {
        // The largest magnitude is |-5| at index 3.
        let values: Array1<i32> = arr1(&[1, 4, 2, -5, 2]);
        assert_eq!(argabsmax(values.view()), 3_usize);
    }

    //     #[test]
    //     fn argabsmx2_i32() {
    //         let a: Array2<i32> = arr2(&[[-1, 2, 100], [-5, -2000, 0], [0, 0, 0]]);
    //         let result = argabsmax2(a.view());
    //         let expected: (usize, usize) = (1, 1);
    //         assert_eq!(result, expected);
    //     }

    #[test]
    fn argabsmx_dual() {
        let negative = Dual::try_new(-2.5, Vec::from(["a".to_string()]), Vec::from([2.0])).unwrap();
        let values: Array1<Dual> = arr1(&[constant(1.0), negative]);
        assert_eq!(argabsmax(values.view()), 1_usize);
    }

    //     #[test]
    //     fn lower_tri_dual() {
    //         let a = arr2(&[
    //             [
    //                 Dual::new(1.0, Vec::new()),
    //                 Dual::new(0.0, Vec::new()),
    //             ],
    //             [
    //                 Dual::new(2.0, Vec::new()),
    //                 Dual::new(1.0, Vec::new()),
    //             ],
    //         ]);
    //         let b = arr1(&[
    //             Dual::new(2.0, Vec::new()),
    //             Dual::new(5.0, Vec::new()),
    //         ]);
    //         let x = dsolve_lower21_(&a.view(), &b.view());
    //         let expected_x = arr1(&[
    //             Dual::new(2.0, Vec::new()),
    //             Dual::new(1.0, Vec::new()),
    //         ]);
    //         assert_eq!(x, expected_x);
    //     }

    #[test]
    fn upper_tri_dual() {
        let u = arr2(&[
            [constant(1.0), constant(2.0)],
            [constant(0.0), constant(1.0)],
        ]);
        let b = arr1(&[constant(2.0), constant(5.0)]);
        let solved = dsolve_upper21_(&u.view(), &b.view());
        // Back substitution: x1 = 5, x0 = 2 - 2 * 5 = -8.
        let expected = arr1(&[constant(-8.0), constant(5.0)]);
        assert_eq!(solved, expected);
    }

    #[test]
    fn dsolve_dual() {
        let identity: Array2<Dual> = Array2::eye(2);
        let b: Array1<Dual> = arr1(&[with_vars(2.0, &["x"]), with_vars(5.0, &["x", "y"])]);
        let result = dsolve(&identity.view(), &b.view(), false);
        let expected = arr1(&[with_vars(2.0, &["x"]), with_vars(5.0, &["x", "y"])]);
        assert_eq!(result, expected);
        assert!(Arc::ptr_eq(&result[0].vars(), &result[1].vars()));
    }

    #[test]
    #[should_panic]
    fn dmul11_p() {
        // Lengths 2 and 1 do not match.
        let lhs = arr1(&[1.0, 2.0]);
        let rhs = arr1(&[1.0]);
        dmul11_(&lhs.view(), &rhs.view());
    }

    #[test]
    #[should_panic]
    fn dmul22_p() {
        // A 2x2 matrix cannot be multiplied by a 1x2 matrix.
        let lhs = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        let rhs = arr2(&[[1.0, 2.0]]);
        dmul22_(&lhs.view(), &rhs.view());
    }

    #[test]
    #[should_panic]
    fn dmul21_p() {
        // A 2x2 matrix cannot be multiplied by a vector of length 1.
        let lhs = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        let rhs = arr1(&[1.0]);
        dmul21_(&lhs.view(), &rhs.view());
    }
}