1use crate::dual::linalg::linalg_dual::{argabsmax, dmul22_, el_swap, row_swap};
4use itertools::Itertools;
5use ndarray::prelude::*;
6use num_traits::identities::Zero;
7use num_traits::Signed;
8use std::cmp::PartialOrd;
9use std::iter::Sum;
10use std::ops::{Mul, Sub};
11
12pub fn fouter11_(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> Array2<f64> {
14 Array1::from_vec(
15 a.iter()
16 .cartesian_product(b.iter())
17 .map(|(x, y)| x * y)
18 .collect(),
19 )
20 .into_shape_with_order((a.len(), b.len()))
21 .expect("Pre checked dimensions")
22}
23
24pub fn fdmul11_<T>(a: &ArrayView1<f64>, b: &ArrayView1<T>) -> T
30where
31 for<'a> &'a f64: Mul<&'a T, Output = T>,
32 T: Sum,
33{
34 assert_eq!(a.len(), b.len());
35 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
36}
37
38pub fn fdmul21_<T>(a: &ArrayView2<f64>, b: &ArrayView1<T>) -> Array1<T>
42where
43 for<'a> &'a f64: Mul<&'a T, Output = T>,
44 T: Sum,
45{
46 assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
47 Array1::from_vec(a.axis_iter(Axis(0)).map(|row| fdmul11_(&row, b)).collect())
48}
49
50pub fn dfmul21_<T>(a: &ArrayView2<T>, b: &ArrayView1<f64>) -> Array1<T>
54where
55 for<'a> &'a f64: Mul<&'a T, Output = T>,
56 T: Sum,
57{
58 assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
59 Array1::from_vec(a.axis_iter(Axis(0)).map(|row| fdmul11_(b, &row)).collect())
60}
61
62pub fn fdmul22_<T>(a: &ArrayView2<f64>, b: &ArrayView2<T>) -> Array2<T>
66where
67 for<'a> &'a f64: Mul<&'a T, Output = T>,
68 T: Sum,
69{
70 assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
71 Array1::<T>::from_vec(
72 a.axis_iter(Axis(0))
73 .cartesian_product(b.axis_iter(Axis(1)))
74 .map(|(row, col)| fdmul11_(&row, &col))
75 .collect(),
76 )
77 .into_shape_with_order((a.len_of(Axis(0)), b.len_of(Axis(1))))
78 .expect("Dim are pre-checked")
79}
80
81pub fn dfmul22_<T>(a: &ArrayView2<T>, b: &ArrayView2<f64>) -> Array2<T>
85where
86 for<'a> &'a f64: Mul<&'a T, Output = T>,
87 T: Sum,
88{
89 assert_eq!(a.len_of(Axis(1)), b.len_of(Axis(0)));
90 Array1::<T>::from_vec(
91 a.axis_iter(Axis(0))
92 .cartesian_product(b.axis_iter(Axis(1)))
93 .map(|(row, col)| fdmul11_(&col, &row))
94 .collect(),
95 )
96 .into_shape_with_order((a.len_of(Axis(0)), b.len_of(Axis(1))))
97 .expect("Dim are pre-checked")
98}
99
100fn fdsolve_upper21_<T>(u: &ArrayView2<f64>, b: &ArrayView1<T>) -> Array1<T>
101where
102 T: Sum + Zero + Clone,
103 for<'a> &'a f64: Mul<&'a T, Output = T>,
104 for<'a> &'a T: Sub<&'a T, Output = T>,
105{
106 let n: usize = u.len_of(Axis(0));
107 let mut x: Array1<T> = Array::zeros(n);
108 for i in (0..n).rev() {
109 let v = &b[i] - &fdmul11_(&u.slice(s![i, (i + 1)..]), &x.slice(s![(i + 1)..]));
110 x[i] = &(1.0_f64 / &u[[i, i]]) * &v
111 }
112 x
113}
114
115fn fdsolve21_<T>(a: &ArrayView2<f64>, b: &ArrayView1<T>) -> Array1<T>
116where
117 T: PartialOrd + Signed + Clone + Zero + Sum,
118 for<'a> &'a f64: Mul<&'a T, Output = T> + Mul<&'a f64, Output = f64>,
119 for<'a> &'a T: Sub<&'a T, Output = T>,
120{
121 assert!(a.is_square());
122 let n = a.len_of(Axis(0));
123 assert_eq!(b.len_of(Axis(0)), n);
124
125 let mut a_ = a.to_owned();
127 let mut b_ = b.to_owned();
128
129 for j in 0..n {
130 let k = argabsmax(a_.slice(s![j.., j])) + j;
131 if j != k {
132 row_swap(&mut a_, &j, &k);
134 el_swap(&mut b_, &j, &k);
135 }
136 for l in (j + 1)..n {
138 let scl: f64 = a_[[l, j]] / a_[[j, j]];
139 a_[[l, j]] = 0.0_f64;
140 for m in (j + 1)..n {
141 a_[[l, m]] -= scl * a_[[j, m]];
142 }
143 b_[l] = &b_[l] - &(&scl * &b_[j]);
144 }
145 }
146 fdsolve_upper21_(&a_.view(), &b_.view())
147}
148
149pub fn fdsolve<T>(a: &ArrayView2<f64>, b: &ArrayView1<T>, allow_lsq: bool) -> Array1<T>
154where
155 T: PartialOrd + Signed + Clone + Zero + Sum,
156 for<'a> &'a f64: Mul<&'a T, Output = T>,
157 for<'a> &'a T: Sub<&'a T, Output = T>,
158{
159 if allow_lsq {
160 let a_: Array2<f64> = dmul22_(&a.t(), a);
161 let b_: Array1<T> = fdmul21_(&a.t(), b);
162 fdsolve21_(&a_.view(), &b_.view())
163 } else {
164 fdsolve21_(a, b)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dual::dual::{Dual, Vars};
    use std::sync::Arc;

    /// Builds a `Dual` with no variables, i.e. a plain constant.
    fn constant(real: f64) -> Dual {
        Dual::new(real, Vec::new())
    }

    #[test]
    fn outer_prod() {
        let left = arr1(&[1.0, 2.0]);
        let right = arr1(&[2.0, 1.0, 3.0]);
        let expected = arr2(&[[2., 1., 3.], [4., 2., 6.]]);
        let product = fouter11_(&left.view(), &right.view());
        assert_eq!(expected, product)
    }

    #[test]
    fn fdupper_tri_dual() {
        let upper = arr2(&[[1., 2.], [0., 1.]]);
        let rhs = arr1(&[constant(2.0), constant(5.0)]);
        let solved = fdsolve_upper21_(&upper.view(), &rhs.view());
        let expected = arr1(&[constant(-8.0), constant(5.0)]);
        assert_eq!(solved, expected);
    }

    #[test]
    fn fdsolve_dual() {
        let vars_x = || vec!["x".to_string()];
        let vars_xy = || vec!["x".to_string(), "y".to_string()];
        let identity: Array2<f64> = Array2::eye(2);
        let rhs: Array1<Dual> = arr1(&[Dual::new(2.0, vars_x()), Dual::new(5.0, vars_xy())]);
        let solved: Array1<Dual> = fdsolve(&identity.view(), &rhs.view(), false);
        let expected = arr1(&[Dual::new(2.0, vars_x()), Dual::new(5.0, vars_xy())]);
        assert_eq!(solved, expected);
        // The solver should leave both outputs sharing one variable set.
        assert!(Arc::ptr_eq(&solved[0].vars(), &solved[1].vars()));
    }

    #[test]
    #[should_panic]
    fn fdmul11_p() {
        let lhs = arr1(&[1.0, 2.0]);
        let rhs = arr1(&[1.0]);
        fdmul11_(&lhs.view(), &rhs.view());
    }

    #[test]
    #[should_panic]
    fn fdmul22_p() {
        let lhs = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        let rhs = arr2(&[[1.0, 2.0]]);
        fdmul22_(&lhs.view(), &rhs.view());
    }

    #[test]
    #[should_panic]
    fn dfmul22_p() {
        let lhs = arr2(&[[1.0, 2.0], [2.0, 3.0]]);
        let rhs = arr2(&[[1.0, 2.0]]);
        dfmul22_(&lhs.view(), &rhs.view());
    }
}