rateslib/dual/linalg/linalg_dual.rs
1//! Perform linear algebra operations on arrays containing generic data types.
2
3use itertools::Itertools;
4use ndarray::prelude::*;
5use ndarray::Zip;
6use num_traits::identities::Zero;
7use num_traits::Signed;
8use std::cmp::PartialOrd;
9use std::iter::Sum;
10use std::ops::{Div, Mul, Sub};
11
12// Tensor ops
13
14/// Outer product of two 1d-arrays containing generic objects.
15pub fn douter11_<T>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> Array2<T>
16where
17 for<'a> &'a T: Mul<&'a T, Output = T>,
18 T: Sum,
19{
20 Array1::from_vec(
21 a.iter()
22 .cartesian_product(b.iter())
23 .map(|(x, y)| x * y)
24 .collect(),
25 )
26 .into_shape_with_order((a.len(), b.len()))
27 .expect("Pre checked dimensions")
28}
29
30/// Inner product between two 1d-arrays.
31pub fn dmul11_<T>(a: &ArrayView1<T>, b: &ArrayView1<T>) -> T
32where
33 for<'a> &'a T: Mul<&'a T, Output = T>,
34 T: Sum,
35{
36 assert_eq!(a.len(), b.len());
37 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
38}
39
40/// Matrix multiplication between a 2d-array and a 1d-array.
41pub fn dmul21_<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
42where
43 for<'a> &'a T: Mul<&'a T, Output = T>,
44 T: Sum,
45{
46 assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
47 Array1::from_vec(a.axis_iter(Axis(0)).map(|row| dmul11_(&row, b)).collect())
48}
49
50/// Matrix multiplication between two 2d-arrays.
51pub fn dmul22_<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> Array2<T>
52where
53 for<'a> &'a T: Mul<&'a T, Output = T>,
54 T: Sum,
55{
56 assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
57 Array1::<T>::from_vec(
58 a.axis_iter(Axis(0))
59 .cartesian_product(b.axis_iter(Axis(1)))
60 .map(|(row, col)| dmul11_(&row, &col))
61 .collect(),
62 )
63 .into_shape_with_order((a.len_of(Axis(0)), b.len_of(Axis(1))))
64 .expect("Dim are pre-checked")
65}
66
67// Linalg solver
68
69pub(crate) fn argabsmax<T>(a: ArrayView1<T>) -> usize
70where
71 T: Signed + PartialOrd,
72{
73 let vi: (&T, usize) = a
74 .iter()
75 .zip(0..)
76 .max_by(|x, y| x.0.abs().partial_cmp(&y.0.abs()).unwrap())
77 .unwrap();
78 vi.1
79}
80
81// pub(crate) fn argabsmax2<T>(a: ArrayView2<T>) -> (usize, usize)
82// where
83// T: Signed + PartialOrd,
84// {
85// let vi: (&T, usize) = a
86// .iter()
87// .zip(0..)
88// .max_by(|x, y| x.0.abs().partial_cmp(&y.0.abs()).unwrap())
89// .unwrap();
90// let n = a.len_of(Axis(0));
91// (vi.1 / n, vi.1 % n)
92// }
93
94pub(crate) fn row_swap<T>(p: &mut Array2<T>, j: &usize, kr: &usize) {
95 let (mut pt, mut pb) = p.slice_mut(s![.., ..]).split_at(Axis(0), *kr);
96 let (r1, r2) = (pt.row_mut(*j), pb.row_mut(0));
97 Zip::from(r1).and(r2).for_each(std::mem::swap);
98}
99
100// pub(crate) fn col_swap<T>(p: &mut Array2<T>, j: &usize, kc: &usize)
101// {
102// let (mut pl, mut pr) = p.slice_mut(s![.., ..]).split_at(Axis(1), *kc);
103// let (c1, c2) = (pl.column_mut(*j), pr.column_mut(0));
104// Zip::from(c1).and(c2).for_each(std::mem::swap);
105// }
106
107pub(crate) fn el_swap<T>(p: &mut Array1<T>, j: &usize, k: &usize) {
108 let (mut pl, mut pr) = p.slice_mut(s![..]).split_at(Axis(0), *k);
109 std::mem::swap(&mut pl[*j], &mut pr[0]);
110}
111
112// fn partial_pivot_matrix<T>(a: &ArrayView2<T>) -> (Array2<f64>, Array2<f64>, Array2<T>)
113// where
114// T: Signed + Num + PartialOrd + Clone,
115// {
116// // pivot square matrix
117// let n = a.len_of(Axis(0));
118// let mut p: Array2<f64> = Array::eye(n);
119// let q: Array2<f64> = Array::eye(n);
120// let mut pa = a.to_owned();
121// for j in 0..n {
122// let k = argabsmax(pa.slice(s![j.., j])) + j;
123// if j != k {
124// // define row swaps j <-> k (note that k > j by definition)
125// let (mut pt, mut pb) = p.slice_mut(s![.., ..]).split_at(Axis(0), k);
126// let (r1, r2) = (pt.row_mut(j), pb.row_mut(0));
127// Zip::from(r1).and(r2).for_each(std::mem::swap);
128//
129// let (mut pt, mut pb) = pa.slice_mut(s![.., ..]).split_at(Axis(0), k);
130// let (r1, r2) = (pt.row_mut(j), pb.row_mut(0));
131// Zip::from(r1).and(r2).for_each(std::mem::swap);
132// }
133// }
134// (p, q, pa)
135// }
136//
137// fn complete_pivot_matrix<T>(a: &ArrayView2<T>) -> (Array2<f64>, Array2<f64>, Array2<T>)
138// where
139// T: Signed + Num + PartialOrd + Clone,
140// {
141// // pivot square matrix
142// let n = a.len_of(Axis(0));
143// let mut p: Array2<f64> = Array::eye(n);
144// let mut q: Array2<f64> = Array::eye(n);
145// let mut at = a.to_owned();
146//
147// for j in 0..n {
148// // iterate diagonally through
149// let (mut kr, mut kc) = argabsmax2(at.slice(s![j.., j..]));
150// kr += j;
//         kc += j; // align with outer-scope array indices
152//
153// match (kr, kc) {
154// (kr, kc) if kr > j && kc > j => {
155// row_swap(&mut p, &j, &kr);
156// row_swap(&mut at, &j, &kr);
157// col_swap(&mut q, &j, &kc);
158// col_swap(&mut at, &j, &kc);
159// }
160// (kr, kc) if kr > j && kc == j => {
161// row_swap(&mut p, &j, &kr);
162// row_swap(&mut at, &j, &kr);
163// }
164// (kr, kc) if kr == j && kc > j => {
165// col_swap(&mut q, &j, &kc);
166// col_swap(&mut at, &j, &kc);
167// }
168// _ => {}
169// }
170// }
171// (p, q, at)
172// }
173//
174// fn rook_pivot_matrix<T>(a: &ArrayView2<T>) -> (Array2<f64>, Array2<f64>, Array2<T>)
175// where
176// T: Signed + Num + PartialOrd + Clone,
177// {
178// // Implement a modified Rook Pivot.
179// // If Original is the largest Abs in the row, and it is greater than some
180// // tolerance then use that. This prevents row swapping where the rightmost columns
181// // are zero, which ultimately leads to failure in sparse matrices.
182//
183// // pivot square matrix
184// let n = a.len_of(Axis(0));
185// let mut p: Array2<f64> = Array::eye(n);
186// let mut q: Array2<f64> = Array::eye(n);
187// let mut at = a.to_owned();
188//
189// for j in 0..n {
190// // iterate diagonally through
191// let kr = argabsmax(at.slice(s![j.., j])) + j;
192// let kc = argabsmax(at.slice(s![j, j..])) + j;
193//
194// match (kr, kc) {
195// (kr, kc) if kr > j && kc > j => {
196// if at[[kr, j]].abs() > at[[j, kc]].abs() {
197// row_swap(&mut p, &j, &kr);
198// row_swap(&mut at, &j, &kr);
199// } else {
200// col_swap(&mut q, &j, &kc);
201// col_swap(&mut at, &j, &kc);
202// }
203// }
204// (kr, kc) if kr > j && kc == j => {
205// // MODIFIER as explained:
206// // if !(at[[j, j]].abs() > 1e-8) {
207// row_swap(&mut p, &j, &kr);
208// row_swap(&mut at, &j, &kr);
209// // }
210// }
211// (kr, kc) if kr == j && kc > j => {
212// col_swap(&mut q, &j, &kc);
213// col_swap(&mut at, &j, &kc);
214// }
215// _ => {}
216// }
217// }
218// (p, q, at)
219// }
220//
221// pub enum PivotMethod {
222// Partial,
223// Complete,
224// Rook,
225// }
226
227// pub fn pluq_decomp<T>(
228// a: &ArrayView2<T>,
229// pivot: PivotMethod,
230// ) -> (Array2<f64>, Array2<T>, Array2<T>, Array2<f64>)
231// where
232// T: Signed + Num + PartialOrd + Clone + One + Zero + Sum + for<'a> Div<&'a T, Output = T>,
233// for<'a> &'a T: Mul<&'a T, Output = T> + Sub<T, Output = T>,
234// {
235// let n: usize = a.len_of(Axis(0));
236// let mut l: Array2<T> = Array2::zeros((n, n));
237// let mut u: Array2<T> = Array2::zeros((n, n));
238// let p;
239// let q;
240// let paq;
241// match pivot {
242// PivotMethod::Partial => (p, q, paq) = partial_pivot_matrix(a),
243// PivotMethod::Complete => (p, q, paq) = complete_pivot_matrix(a),
244// PivotMethod::Rook => {
245// (p, q, paq) = rook_pivot_matrix(a);
246// }
247// }
248//
249// let one = T::one();
250// for j in 0..n {
251// l[[j, j]] = one.clone(); // all diagonal entries of L are set to unity
252//
253// for i in 0..j + 1 {
254// // LaTeX: u_{ij} = a_{ij} - \sum_{k=1}^{i-1} u_{kj} l_{ik}
255// let sx = dmul11_(&l.slice(s![i, ..i]), &u.slice(s![..i, j]));
256// u[[i, j]] = &paq[[i, j]] - sx;
257// }
258//
259// for i in j..n {
260// // LaTeX: l_{ij} = \frac{1}{u_{jj}} (a_{ij} - \sum_{k=1}^{j-1} u_{kj} l_{ik})
261// let sy = dmul11_(&l.slice(s![i, ..j]), &u.slice(s![..j, j]));
262// l[[i, j]] = (&paq[[i, j]] - sy) / &u[[j, j]];
263// }
264// }
265// (p, l, u, q)
266// }
267
268// fn dsolve_lower21_<T>(l: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
269// where
270// T: Clone + Sum + Zero,
271// for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>
272// {
273// let n: usize = l.len_of(Axis(0));
274// let mut x: Array1<T> = Array::zeros(n);
275// for i in 0..n {
276// let v = &b[i] - &dmul11_(&l.slice(s![i, ..i]), &x.slice(s![..i]));
277// x[i] = &v / &l[[i, i]]
278// }
279// x
280// }
281
282fn dsolve_upper21_<T>(u: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
283where
284 T: Clone + Sum + Zero,
285 for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>,
286{
287 let n: usize = u.len_of(Axis(0));
288 let mut x: Array1<T> = Array::zeros(n);
289 for i in (0..n).rev() {
290 let v = &b[i] - &dmul11_(&u.slice(s![i, (i + 1)..]), &x.slice(s![(i + 1)..]));
291 x[i] = &v / &u[[i, i]]
292 }
293 x
294}
295
296fn dsolve21_<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
297where
298 T: PartialOrd + Signed + Clone + Zero + Sum,
299 for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>,
300{
301 assert!(a.is_square());
302 let n = a.len_of(Axis(0));
303 assert_eq!(b.len_of(Axis(0)), n);
304
305 // a_ and b_ will be pivoted and amended throughout the solution
306 let mut a_ = a.to_owned();
307 let mut b_ = b.to_owned();
308
309 for j in 0..n {
310 let k = argabsmax(a_.slice(s![j.., j])) + j;
311 if j != k {
312 // define row swaps j <-> k (note that k > j by definition)
313 row_swap(&mut a_, &j, &k);
314 el_swap(&mut b_, &j, &k);
315 }
316 // perform reduction on subsequent rows below j
317 for l in (j + 1)..n {
318 let scl = &a_[[l, j]] / &a_[[j, j]];
319 a_[[l, j]] = T::zero();
320 for m in (j + 1)..n {
321 a_[[l, m]] = &a_[[l, m]] - &(&scl * &a_[[j, m]]);
322 }
323 b_[l] = &b_[l] - &(&scl * &b_[j]);
324 }
325 }
326 dsolve_upper21_(&a_.view(), &b_.view())
327}
328
329// fn dsolve_upper_1d<T>(u: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
330// where
331// T: Clone + Sum + Zero + for<'a> Div<&'a T, Output = T>,
332// for<'a> &'a T: Sub<T, Output = T> + Mul<&'a T, Output = T>,
333// {
334// // reverse all dimensions and solve as lower triangular
335// dsolve_lower_1d(&u.slice(s![..;-1, ..;-1]), &b.slice(s![..;-1]))
336// .slice(s![..;-1])
337// .to_owned()
338// }
339
340// fn dsolve21_<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> Array1<T>
341// where
342// T: PartialOrd + Signed + Clone + Sum + Zero + for<'a> Div<&'a T, Output = T>,
343// for<'a> &'a T: Mul<&'a f64, Output = T> + Sub<T, Output = T> + Mul<&'a T, Output = T>,
344// for<'a> &'a f64: Mul<&'a T, Output = T>,
345// {
346// let (p, l, u, q) = pluq_decomp::<T>(&a.view(), PivotMethod::Complete);
347// let pb: Array1<T> = fdmul21_(&p.view(), &b.view());
348// let z: Array1<T> = dsolve_lower_1d(&l.view(), &pb.view());
349// let y: Array1<T> = dsolve_upper_1d(&u.view(), &z.view());
350// let x: Array1<T> = fdmul21_(&q.view(), &y.view());
351// x
352// }
353
354/// Solve a linear system of equations, ax = b, using Gaussian elimination and partial pivoting.
355///
356/// - `a` is a 2d-array.
357/// - `b` is a 1d-array.
358/// - `allow_lsq` can be set to `true` if the number of rows in `a` is greater than its number of columns.
359pub fn dsolve<T>(a: &ArrayView2<T>, b: &ArrayView1<T>, allow_lsq: bool) -> Array1<T>
360where
361 T: PartialOrd + Signed + Clone + Sum + Zero,
362 for<'a> &'a T: Sub<&'a T, Output = T> + Mul<&'a T, Output = T> + Div<&'a T, Output = T>,
363{
364 if allow_lsq {
365 let a_ = dmul22_(&a.t(), a);
366 let b_ = dmul21_(&a.t(), b);
367 dsolve21_(&a_.view(), &b_.view())
368 } else {
369 dsolve21_(a, b)
370 }
371}
372
373// UNIT TESTS
374
375//
376
377//
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dual::dual::{Dual, Vars};
    use std::sync::Arc;

    // Shorthand for a `Dual` built from `real` and an empty list of variable names.
    fn constant(real: f64) -> Dual {
        Dual::new(real, Vec::new())
    }

    // fn is_close(a: &f64, b: &f64, abs_tol: Option<f64>) -> bool {
    //     // used rather than equality for float numbers
    //     (a - b).abs() < abs_tol.unwrap_or(1e-8)
    // }

    // The largest magnitude is found even when the element is negative.
    #[test]
    fn argabsmx_i32() {
        let values: Array1<i32> = arr1(&[1, 4, 2, -5, 2]);
        let found = argabsmax(values.view());
        assert_eq!(found, 3_usize);
    }

    // #[test]
    // fn argabsmx2_i32() {
    //     let a: Array2<i32> = arr2(&[[-1, 2, 100], [-5, -2000, 0], [0, 0, 0]]);
    //     let result = argabsmax2(a.view());
    //     let expected: (usize, usize) = (1, 1);
    //     assert_eq!(result, expected);
    // }

    // `argabsmax` also works on `Dual` elements.
    #[test]
    fn argabsmx_dual() {
        let first = constant(1.0);
        let second = Dual::try_new(-2.5, vec!["a".to_string()], vec![2.0]).unwrap();
        let values: Array1<Dual> = arr1(&[first, second]);
        let found = argabsmax(values.view());
        assert_eq!(found, 1_usize);
    }

    // #[test]
    // fn lower_tri_dual() {
    //     let a = arr2(&[
    //         [
    //             Dual::new(1.0, Vec::new()),
    //             Dual::new(0.0, Vec::new()),
    //         ],
    //         [
    //             Dual::new(2.0, Vec::new()),
    //             Dual::new(1.0, Vec::new()),
    //         ],
    //     ]);
    //     let b = arr1(&[
    //         Dual::new(2.0, Vec::new()),
    //         Dual::new(5.0, Vec::new()),
    //     ]);
    //     let x = dsolve_lower21_(&a.view(), &b.view());
    //     let expected_x = arr1(&[
    //         Dual::new(2.0, Vec::new()),
    //         Dual::new(1.0, Vec::new()),
    //     ]);
    //     assert_eq!(x, expected_x);
    // }

    // Back substitution on a 2x2 upper-triangular system.
    #[test]
    fn upper_tri_dual() {
        let upper = arr2(&[
            [constant(1.0), constant(2.0)],
            [constant(0.0), constant(1.0)],
        ]);
        let rhs = arr1(&[constant(2.0), constant(5.0)]);
        let solved = dsolve_upper21_(&upper.view(), &rhs.view());
        let expected = arr1(&[constant(-8.0), constant(5.0)]);
        assert_eq!(solved, expected);
    }

    // Solving against the identity returns the right-hand side, and both solution
    // elements end up sharing one variables allocation.
    #[test]
    fn dsolve_dual() {
        let build_rhs = || {
            arr1(&[
                Dual::new(2.0, vec!["x".to_string()]),
                Dual::new(5.0, vec!["x".to_string(), "y".to_string()]),
            ])
        };
        let identity: Array2<Dual> = Array2::eye(2);
        let rhs: Array1<Dual> = build_rhs();
        let result = dsolve(&identity.view(), &rhs.view(), false);
        let expected = build_rhs();
        assert_eq!(result, expected);
        assert!(Arc::ptr_eq(&result[0].vars(), &result[1].vars()));
    }

    // Mismatched lengths must panic.
    #[test]
    #[should_panic]
    fn dmul11_p() {
        let long = arr1(&[1.0, 2.0]);
        let short = arr1(&[1.0]);
        dmul11_(&long.view(), &short.view());
    }

    // Inner dimensions that do not agree must panic.
    #[test]
    #[should_panic]
    fn dmul22_p() {
        let left = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        let right = arr2(&[[1.0, 2.0]]);
        dmul22_(&left.view(), &right.view());
    }

    // A vector shorter than the matrix width must panic.
    #[test]
    #[should_panic]
    fn dmul21_p() {
        let matrix = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        let vector = arr1(&[1.0]);
        dmul21_(&matrix.view(), &vector.view());
    }
}