rateslib/dual/
dual.rs

1pub use crate::dual::dual_ops::convert::{set_order, set_order_clone};
2pub use crate::dual::dual_ops::math_funcs::MathFuncs;
3pub use crate::dual::dual_ops::numeric_ops::NumberOps;
4use indexmap::set::IndexSet;
5use ndarray::{Array, Array1, Array2, Axis};
6use pyo3::exceptions::PyValueError;
7use pyo3::{pyclass, PyErr};
8use serde::{Deserialize, Serialize};
9use std::cmp::PartialEq;
10use std::sync::Arc;
11
/// A dual number data type supporting first order derivatives.
///
/// A `Dual` holds a real component together with first order gradients with respect to a set
/// of named variables. The variable names live behind an `Arc` so that several dual numbers can
/// share the same `vars` pointer, which lets operations skip reshuffling gradients
/// (see [`VarsRelationship::ArcEquivalent`]).
#[pyclass(module = "rateslib.rs")]
#[derive(Clone, Default, Debug, Deserialize, Serialize)]
pub struct Dual {
    /// The real component value.
    pub(crate) real: f64,
    /// The ordered, unique variable names which index the gradient array.
    pub(crate) vars: Arc<IndexSet<String>>,
    /// First order gradients, one per entry of `vars` and in the same order.
    pub(crate) dual: Array1<f64>,
}
20
/// A dual number data type supporting second order derivatives.
///
/// A `Dual2` holds a real component, first order gradients and second order gradient values,
/// all indexed by a shared, ordered set of variable names held behind an `Arc`.
#[pyclass(module = "rateslib.rs")]
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
pub struct Dual2 {
    /// The real component value.
    pub(crate) real: f64,
    /// The ordered, unique variable names which index both gradient arrays.
    pub(crate) vars: Arc<IndexSet<String>>,
    /// First order gradients, one per entry of `vars` and in the same order.
    pub(crate) dual: Array1<f64>,
    /// Square array of second order values indexed by `vars` in both dimensions.
    /// Note that `Gradient2::gradient2` reports these values multiplied by 2.
    pub(crate) dual2: Array2<f64>,
}
30
/// The state of the `vars` measured between two dual number type structs; a LHS relative to a RHS.
///
/// This is a plain fieldless enum, so it is `Copy`, `Eq` and `Hash` and can be passed around
/// by value freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum VarsRelationship {
    /// The two structs share the same Arc pointer for their `vars`.
    ArcEquivalent,
    /// The structs have the same `vars` in the same order but not a shared Arc pointer.
    ValueEquivalent,
    /// The `vars` of the compared RHS is contained within those of the LHS.
    Superset,
    /// The `vars` of the calling LHS are contained within those of the RHS.
    Subset,
    /// Both the LHS and RHS have different `vars`.
    Difference,
}
45
46/// Manages the `vars` of the manifold associated with a dual number.
47pub trait Vars
48where
49    Self: Clone,
50{
51    /// Get a reference to the Arc pointer for the `IndexSet` containing the struct's variables.
52    fn vars(&self) -> &Arc<IndexSet<String>>;
53
54    /// Create a new dual number with `vars` aligned with given new Arc pointer.
55    ///
56    /// This method compares the existing `vars` with the new and reshuffles manifold gradient
57    /// values in memory. For large numbers of variables this is one of the least efficient
58    /// operations relating different dual numbers and should be avoided where possible.
59    fn to_new_vars(
60        &self,
61        arc_vars: &Arc<IndexSet<String>>,
62        state: Option<VarsRelationship>,
63    ) -> Self;
64
65    /// Compare the `vars` on a `Dual` with a given Arc pointer.
66    fn vars_cmp(&self, arc_vars: &Arc<IndexSet<String>>) -> VarsRelationship {
67        if Arc::ptr_eq(self.vars(), arc_vars) {
68            VarsRelationship::ArcEquivalent
69        } else if self.vars().len() == arc_vars.len()
70            && self.vars().iter().zip(arc_vars.iter()).all(|(a, b)| a == b)
71        {
72            VarsRelationship::ValueEquivalent
73        } else if self.vars().len() >= arc_vars.len()
74            && arc_vars.iter().all(|var| self.vars().contains(var))
75        {
76            VarsRelationship::Superset
77        } else if self.vars().len() < arc_vars.len()
78            && self.vars().iter().all(|var| arc_vars.contains(var))
79        {
80            VarsRelationship::Subset
81        } else {
82            VarsRelationship::Difference
83        }
84    }
85    // fn vars_cmp(&self, arc_vars: &Arc<IndexSet<String>>) -> VarsRelationship;
86
87    /// Construct a tuple of 2 `Self` types whose `vars` are linked by an Arc pointer.
88    ///
89    /// Gradient values contained in fields may be shuffled in memory if necessary
90    /// according to the calculated `VarsRelationship`. Do not use `state` directly unless you have
91    /// performed a pre-check.
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// # use rateslib::dual::{Dual, Vars, VarsRelationship};
97    /// let x = Dual::new(1.0, vec!["x".to_string()]);
98    /// let y = Dual::new(1.5, vec!["y".to_string()]);
99    /// let (a, b) = x.to_union_vars(&y, Some(VarsRelationship::Difference));
100    /// // a: <Dual: 1.0, (x, y), [1.0, 0.0]>
101    /// // b: <Dual: 1.5, (x, y), [0.0, 1.0]>
102    /// ```
103    fn to_union_vars(&self, other: &Self, state: Option<VarsRelationship>) -> (Self, Self)
104    where
105        Self: Sized,
106    {
107        let state_ = state.unwrap_or_else(|| self.vars_cmp(other.vars()));
108        match state_ {
109            VarsRelationship::ArcEquivalent => (self.clone(), other.clone()),
110            VarsRelationship::ValueEquivalent => {
111                (self.clone(), other.to_new_vars(self.vars(), Some(state_)))
112            }
113            VarsRelationship::Superset => (
114                self.clone(),
115                other.to_new_vars(self.vars(), Some(VarsRelationship::Subset)),
116            ),
117            VarsRelationship::Subset => {
118                (self.to_new_vars(other.vars(), Some(state_)), other.clone())
119            }
120            VarsRelationship::Difference => self.to_combined_vars(other),
121        }
122    }
123
124    /// Construct a tuple of 2 `Self` types whose `vars` are linked by the explicit union
125    /// of their own variables.
126    ///
127    /// Gradient values contained in fields will be shuffled in memory.
128    fn to_combined_vars(&self, other: &Self) -> (Self, Self)
129    where
130        Self: Sized,
131    {
132        let comb_vars = Arc::new(IndexSet::from_iter(
133            self.vars().union(other.vars()).cloned(),
134        ));
135        (
136            self.to_new_vars(&comb_vars, Some(VarsRelationship::Difference)),
137            other.to_new_vars(&comb_vars, Some(VarsRelationship::Difference)),
138        )
139    }
140
141    /// Compare if two `Dual` structs share the same `vars` by Arc pointer equivalence.
142    ///
143    /// # Examples
144    ///
145    /// ```rust
146    /// # use rateslib::dual::{Dual, Vars};
147    /// let x1 = Dual::new(1.5, vec!["x".to_string()]);
148    /// let x2 = Dual::new(2.5, vec!["x".to_string()]);
149    /// assert_eq!(x1.ptr_eq(&x2), false); // Vars are the same but not a shared Arc pointer
150    /// ```
151    fn ptr_eq(&self, other: &Self) -> bool {
152        Arc::ptr_eq(self.vars(), other.vars())
153    }
154}
155
156impl Vars for Dual {
157    /// Get a reference to the Arc pointer for the `IndexSet` containing the struct's variables.
158    fn vars(&self) -> &Arc<IndexSet<String>> {
159        &self.vars
160    }
161
162    /// Construct a new `Dual` with `vars` set as the given Arc pointer and gradients shuffled in memory.
163    ///
164    /// Examples
165    ///
166    /// ```rust
167    /// # use rateslib::dual::{Dual, Vars};
168    /// let x = Dual::new(1.5, vec!["x".to_string()]);
169    /// let xy = Dual::new(2.5, vec!["x".to_string(), "y".to_string()]);
170    /// let x_y = x.to_new_vars(xy.vars(), None);
171    /// // x_y: <Dual: 1.5, (x, y), [1.0, 0.0]>
172    /// assert_eq!(x_y, Dual::try_new(1.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap());
173    fn to_new_vars(
174        &self,
175        arc_vars: &Arc<IndexSet<String>>,
176        state: Option<VarsRelationship>,
177    ) -> Self {
178        let match_val = state.unwrap_or_else(|| self.vars_cmp(arc_vars));
179        let dual_: Array1<f64> = match match_val {
180            VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
181                self.dual.clone()
182            }
183            _ => {
184                let lookup_or_zero = |v| match self.vars.get_index_of(v) {
185                    Some(idx) => self.dual[idx],
186                    None => 0.0_f64,
187                };
188                Array1::from_vec(arc_vars.iter().map(lookup_or_zero).collect())
189            }
190        };
191        Self {
192            real: self.real,
193            vars: Arc::clone(arc_vars),
194            dual: dual_,
195        }
196    }
197}
198
199impl Vars for Dual2 {
200    /// Get a reference to the Arc pointer for the `IndexSet` containing the struct's variables.
201    fn vars(&self) -> &Arc<IndexSet<String>> {
202        &self.vars
203    }
204
205    /// Construct a new `Dual2` with `vars` set as the given Arc pointer and gradients shuffled in memory.
206    ///
207    /// Examples
208    ///
209    /// ```rust
210    /// # use rateslib::dual::{Dual2, Vars};
211    /// let x = Dual2::new(1.5, vec!["x".to_string()]);
212    /// let xy = Dual2::new(2.5, vec!["x".to_string(), "y".to_string()]);
213    /// let x_y = x.to_new_vars(xy.vars(), None);
214    /// // x_y: <Dual2: 1.5, (x, y), [1.0, 0.0], [[0.0, 0.0], [0.0, 0.0]]>
215    /// assert_eq!(x_y, Dual2::try_new(1.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0], vec![]).unwrap());
216    fn to_new_vars(
217        &self,
218        arc_vars: &Arc<IndexSet<String>>,
219        state: Option<VarsRelationship>,
220    ) -> Self {
221        let dual_: Array1<f64>;
222        let mut dual2_: Array2<f64> = Array2::zeros((arc_vars.len(), arc_vars.len()));
223        let match_val = state.unwrap_or_else(|| self.vars_cmp(arc_vars));
224        match match_val {
225            VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
226                dual_ = self.dual.clone();
227                dual2_.clone_from(&self.dual2);
228            }
229            _ => {
230                let lookup_or_zero = |v| match self.vars.get_index_of(v) {
231                    Some(idx) => self.dual[idx],
232                    None => 0.0_f64,
233                };
234                dual_ = Array1::from_vec(arc_vars.iter().map(lookup_or_zero).collect());
235
236                let indices: Vec<Option<usize>> =
237                    arc_vars.iter().map(|x| self.vars.get_index_of(x)).collect();
238                for (i, row_index) in indices.iter().enumerate() {
239                    match row_index {
240                        Some(row_value) => {
241                            for (j, col_index) in indices.iter().enumerate() {
242                                match col_index {
243                                    Some(col_value) => {
244                                        dual2_[[i, j]] = self.dual2[[*row_value, *col_value]]
245                                    }
246                                    None => {}
247                                }
248                            }
249                        }
250                        None => {}
251                    }
252                }
253            }
254        }
255        Self {
256            real: self.real,
257            vars: Arc::clone(arc_vars),
258            dual: dual_,
259            dual2: dual2_,
260        }
261    }
262}
263
264/// Provides calculations of first order gradients to all, or a set of provided, `vars`.
265pub trait Gradient1: Vars {
266    /// Get a reference to the Array containing the first order gradients.
267    fn dual(&self) -> &Array1<f64>;
268
269    /// Return a set of first order gradients ordered by the given vector.
270    ///
271    /// Duplicate `vars` are dropped before parsing.
272    fn gradient1(&self, vars: Vec<String>) -> Array1<f64> {
273        let arc_vars = Arc::new(IndexSet::from_iter(vars));
274        let state = self.vars_cmp(&arc_vars);
275        match state {
276            VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
277                self.dual().clone()
278            }
279            _ => {
280                let mut dual_ = Array1::<f64>::zeros(arc_vars.len());
281                for (i, index) in arc_vars
282                    .iter()
283                    .map(|x| self.vars().get_index_of(x))
284                    .enumerate()
285                {
286                    if let Some(value) = index {
287                        dual_[i] = self.dual()[value]
288                    }
289                }
290                dual_
291            }
292        }
293    }
294}
295
296impl Gradient1 for Dual {
297    fn dual(&self) -> &Array1<f64> {
298        &self.dual
299    }
300}
301
302impl Gradient1 for Dual2 {
303    fn dual(&self) -> &Array1<f64> {
304        &self.dual
305    }
306}
307
308/// Provides calculations of second order gradients to all, or a set of provided, `vars`.
309pub trait Gradient2: Gradient1 {
310    /// Get a reference to the Array containing the second order gradients.
311    fn dual2(&self) -> &Array2<f64>;
312
313    /// Return a set of first order gradients ordered by the given vector.
314    ///
315    /// Duplicate `vars` are dropped before parsing.
316    fn gradient2(&self, vars: Vec<String>) -> Array2<f64> {
317        let arc_vars = Arc::new(IndexSet::from_iter(vars));
318        let state = self.vars_cmp(&arc_vars);
319        match state {
320            VarsRelationship::ArcEquivalent | VarsRelationship::ValueEquivalent => {
321                2.0_f64 * self.dual2()
322            }
323            _ => {
324                let indices: Vec<Option<usize>> = arc_vars
325                    .iter()
326                    .map(|x| self.vars().get_index_of(x))
327                    .collect();
328                let mut dual2_ = Array::zeros((arc_vars.len(), arc_vars.len()));
329                for (i, row_index) in indices.iter().enumerate() {
330                    for (j, col_index) in indices.iter().enumerate() {
331                        match row_index {
332                            Some(row_value) => match col_index {
333                                Some(col_value) => {
334                                    dual2_[[i, j]] = self.dual2()[[*row_value, *col_value]]
335                                }
336                                None => {}
337                            },
338                            None => {}
339                        }
340                    }
341                }
342                2_f64 * dual2_
343            }
344        }
345    }
346
347    fn gradient1_manifold(&self, vars: Vec<String>) -> Array1<Dual2> {
348        let indices: Vec<Option<usize>> =
349            vars.iter().map(|x| self.vars().get_index_of(x)).collect();
350
351        let default_zero = Dual2::new(0., vars.clone());
352        let mut grad: Array1<Dual2> = Array1::zeros(vars.len());
353        for (i, i_idx) in indices.iter().enumerate() {
354            match i_idx {
355                Some(i_val) => {
356                    let mut dual: Array1<f64> = Array1::zeros(vars.len());
357                    for (j, j_idx) in indices.iter().enumerate() {
358                        match j_idx {
359                            Some(j_val) => dual[j] = self.dual2()[[*i_val, *j_val]] * 2.0,
360                            None => {}
361                        }
362                    }
363                    grad[i] = Dual2 {
364                        real: self.dual()[*i_val],
365                        vars: Arc::clone(&default_zero.vars),
366                        dual2: Array2::zeros((vars.len(), vars.len())),
367                        dual,
368                    };
369                }
370                None => grad[i] = default_zero.clone(),
371            }
372        }
373        grad
374    }
375}
376
377impl Gradient2 for Dual2 {
378    fn dual2(&self) -> &Array2<f64> {
379        &self.dual2
380    }
381}
382
383impl Dual {
384    /// Constructs a new `Dual`.
385    ///
386    /// - `vars` should be **unique**; duplicates will be removed by the `IndexSet`.
387    ///
388    /// Gradient values for each of the provided `vars` is set to 1.0_f64.
389    ///
390    /// # Examples
391    ///
392    /// ```rust
393    /// # use rateslib::dual::Dual;
394    /// let x = Dual::new(2.5, vec!["x".to_string()]);
395    /// // x: <Dual: 2.5, (x), [1.0]>
396    /// ```
397    pub fn new(real: f64, vars: Vec<String>) -> Self {
398        let unique_vars_ = Arc::new(IndexSet::from_iter(vars));
399        Self {
400            real,
401            dual: Array1::ones(unique_vars_.len()),
402            vars: unique_vars_,
403        }
404    }
405
406    /// Constructs a new `Dual`.
407    ///
408    /// - `vars` should be **unique**; duplicates will be removed by the `IndexSet`.
409    /// - `dual` can be empty; if so each gradient with respect to each `vars` is set to 1.0_f64.
410    ///
411    /// `try_new` should be used instead of `new` when gradient values other than 1.0_f64 are to
412    /// be initialised.
413    ///
414    /// # Errors
415    ///
416    /// If the length of `dual` and of `vars` are not the same after parsing.
417    ///
418    /// # Examples
419    ///
420    /// ```rust
421    /// # use rateslib::dual::Dual;
422    /// let x = Dual::try_new(2.5, vec!["x".to_string()], vec![4.2]).unwrap();
423    /// // x: <Dual: 2.5, (x), [4.2]>
424    /// ```
425    pub fn try_new(real: f64, vars: Vec<String>, dual: Vec<f64>) -> Result<Self, PyErr> {
426        let unique_vars_ = Arc::new(IndexSet::from_iter(vars));
427        let dual_ = if dual.is_empty() {
428            Array1::ones(unique_vars_.len())
429        } else {
430            Array1::from_vec(dual)
431        };
432        if unique_vars_.len() != dual_.len() {
433            Err(PyValueError::new_err(
434                "`vars` and `dual` must have the same length.",
435            ))
436        } else {
437            Ok(Self {
438                real,
439                vars: unique_vars_,
440                dual: dual_,
441            })
442        }
443    }
444
445    /// Construct a new `Dual` cloning the `vars` Arc pointer from another.
446    ///
447    /// # Examples
448    ///
449    /// ```rust
450    /// # use rateslib::dual::Dual;
451    /// let x = Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap();
452    /// let y1 = Dual::new_from(&x, 1.5, vec!["y".to_string()]);
453    /// ```
454    ///
455    /// This is semantically the same as:
456    ///
457    /// ```rust
458    /// # use rateslib::dual::{Dual, Vars};
459    /// # let x = Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap();
460    /// # let y1 = Dual::new_from(&x, 1.5, vec!["y".to_string()]);
461    /// let y2 = Dual::new(1.5, vec!["y".to_string()]).to_new_vars(x.vars(), None);
462    /// assert_eq!(y1, y2);
463    /// ```
464    pub fn new_from<T: Vars>(other: &T, real: f64, vars: Vec<String>) -> Self {
465        let new = Self::new(real, vars);
466        new.to_new_vars(other.vars(), None)
467    }
468
469    /// Construct a new `Dual` cloning the `vars` Arc pointer from another.
470    ///
471    /// # Examples
472    ///
473    /// ```rust
474    /// # use rateslib::dual::{Dual, Vars};
475    /// let x = Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap();
476    /// let y1 = Dual::try_new_from(&x, 1.5, vec!["y".to_string()], vec![3.2]).unwrap();
477    /// ```
478    ///
479    /// This is semantically the same as:
480    ///
481    /// ```rust
482    /// # use rateslib::dual::{Dual, Vars};
483    /// # let x = Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap();
484    /// # let y1 = Dual::try_new_from(&x, 1.5, vec!["y".to_string()], vec![3.2]).unwrap();
485    /// let y2 = Dual::try_new(1.5, vec!["y".to_string()], vec![3.2]).unwrap().to_new_vars(x.vars(), None);
486    /// assert_eq!(y1, y2);
487    /// ```
488    pub fn try_new_from<T: Vars>(
489        other: &T,
490        real: f64,
491        vars: Vec<String>,
492        dual: Vec<f64>,
493    ) -> Result<Self, PyErr> {
494        let new = Self::try_new(real, vars, dual)?;
495        Ok(new.to_new_vars(other.vars(), None))
496    }
497
498    /// Construct a new `Dual` cloning the `vars` Arc pointer from another.
499    ///
500    pub fn clone_from<T: Vars>(other: &T, real: f64, dual: Array1<f64>) -> Self {
501        assert_eq!(other.vars().len(), dual.len());
502        Dual {
503            real,
504            vars: Arc::clone(other.vars()),
505            dual,
506        }
507    }
508
509    /// Get the real component value of the struct.
510    pub fn real(&self) -> f64 {
511        self.real
512    }
513}
514
515impl Dual2 {
516    /// Constructs a new `Dual2`.
517    ///
518    /// - `vars` should be **unique**; duplicates will be removed by the `IndexSet`.
519    ///
520    /// Gradient values for each of the provided `vars` is set to 1.0_f64.
521    /// Second order gradient values for each combination of provided `vars` is set
522    /// to 0.0_f64.
523    ///
524    /// # Examples
525    ///
526    /// ```rust
527    /// # use rateslib::dual::Dual2;
528    /// let x = Dual2::new(2.5, vec!["x".to_string()]);
529    /// // x: <Dual2: 2.5, (x), [1.0], [[0.0]]>
530    /// ```
531    pub fn new(real: f64, vars: Vec<String>) -> Self {
532        let unique_vars_ = Arc::new(IndexSet::from_iter(vars));
533        Self {
534            real,
535            dual: Array1::ones(unique_vars_.len()),
536            dual2: Array2::zeros((unique_vars_.len(), unique_vars_.len())),
537            vars: unique_vars_,
538        }
539    }
540
541    /// Constructs a new `Dual2`.
542    ///
543    /// - `vars` should be **unique**; duplicates will be removed by the `IndexSet`.
544    /// - `dual` can be empty; if so each gradient with respect to each `vars` is set to 1.0_f64.
545    /// - `dual2` can be empty; if so each gradient with respect to each `vars` is set to 0.0_f64.
546    ///   Input as a flattened 2d-array in row major order.
547    ///
548    /// # Errors
549    ///
550    /// If the length of `dual` and of `vars` are not the same after parsing.
551    /// If the shape of two dimension `dual2` does not match `vars` after parsing.
552    ///
553    /// # Examples
554    ///
555    /// ```rust
556    /// # use rateslib::dual::Dual2;
557    /// let x = Dual2::try_new(2.5, vec!["x".to_string()], vec![], vec![]).unwrap();
558    /// // x: <Dual2: 2.5, (x), [1.0], [[0.0]]>
559    /// ```
560    pub fn try_new(
561        real: f64,
562        vars: Vec<String>,
563        dual: Vec<f64>,
564        dual2: Vec<f64>,
565    ) -> Result<Self, PyErr> {
566        let unique_vars_ = Arc::new(IndexSet::from_iter(vars));
567        let dual_ = if dual.is_empty() {
568            Array1::ones(unique_vars_.len())
569        } else {
570            Array1::from_vec(dual)
571        };
572        if unique_vars_.len() != dual_.len() {
573            return Err(PyValueError::new_err(
574                "`vars` and `dual` must have the same length.",
575            ));
576        }
577
578        let dual2_ = if dual2.is_empty() {
579            Array2::zeros((unique_vars_.len(), unique_vars_.len()))
580        } else {
581            if dual2.len() != (unique_vars_.len() * unique_vars_.len()) {
582                return Err(PyValueError::new_err(
583                    "`vars` and `dual2` must have compatible lengths.",
584                ));
585            }
586            Array::from_vec(dual2)
587                .into_shape_with_order((unique_vars_.len(), unique_vars_.len()))
588                .expect("Reshaping failed, which should not occur because shape is pre-checked.")
589        };
590        Ok(Self {
591            real,
592            vars: unique_vars_,
593            dual: dual_,
594            dual2: dual2_,
595        })
596    }
597
598    /// Construct a new `Dual2` cloning the `vars` Arc pointer from another.
599    ///
600    /// # Examples
601    ///
602    /// ```rust
603    /// # use rateslib::dual::Dual2;
604    /// let x = Dual2::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0], vec![]).unwrap();
605    /// let y1 = Dual2::new_from(&x, 1.5, vec!["y".to_string()]);
606    /// ```
607    ///
608    /// This is semantically the same as:
609    ///
610    /// ```rust
611    /// # use rateslib::dual::{Dual2, Vars};
612    /// # let x = Dual2::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0], vec![]).unwrap();
613    /// # let y1 = Dual2::new_from(&x, 1.5, vec!["y".to_string()]);
614    /// let y = Dual2::new(1.5, vec!["y".to_string()]).to_new_vars(x.vars(), None);
615    /// ```
616    pub fn new_from<T: Vars>(other: &T, real: f64, vars: Vec<String>) -> Self {
617        let new = Self::new(real, vars);
618        new.to_new_vars(other.vars(), None)
619    }
620
621    /// Construct a new `Dual2` cloning the `vars` Arc pointer from another.
622    ///
623    /// # Examples
624    ///
625    /// ```rust
626    /// # use rateslib::dual::Dual2;
627    /// let x = Dual2::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0], vec![]).unwrap();
628    /// let y1 = Dual2::new_from(&x, 1.5, vec!["y".to_string()]);
629    /// ```
630    ///
631    /// This is semantically the same as:
632    ///
633    /// ```rust
634    /// # use rateslib::dual::{Dual2, Vars};
635    /// # let x = Dual2::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0], vec![]).unwrap();
636    /// # let y1 = Dual2::new_from(&x, 1.5, vec!["y".to_string()]);
637    /// let y2 = Dual2::new(1.5, vec!["y".to_string()]).to_new_vars(x.vars(), None);
638    /// assert_eq!(y1, y2);
639    /// ```
640    pub fn try_new_from<T: Vars>(
641        other: &T,
642        real: f64,
643        vars: Vec<String>,
644        dual: Vec<f64>,
645        dual2: Vec<f64>,
646    ) -> Result<Self, PyErr> {
647        let new = Self::try_new(real, vars, dual, dual2)?;
648        Ok(new.to_new_vars(other.vars(), None))
649    }
650
651    /// Construct a new `Dual2` cloning the `vars` Arc pointer from another.
652    ///
653    pub fn clone_from<T: Vars>(
654        other: &T,
655        real: f64,
656        dual: Array1<f64>,
657        dual2: Array2<f64>,
658    ) -> Self {
659        assert_eq!(other.vars().len(), dual.len());
660        assert_eq!(other.vars().len(), dual2.len_of(Axis(0)));
661        assert_eq!(other.vars().len(), dual2.len_of(Axis(1)));
662        Dual2 {
663            real,
664            vars: Arc::clone(other.vars()),
665            dual,
666            dual2,
667        }
668    }
669
670    /// Get the real component value of the struct.
671    pub fn real(&self) -> f64 {
672        self.real
673    }
674}
675
676// UNIT TESTS
677#[cfg(test)]
678mod tests {
679    use super::*;
680    use crate::dual::dual::Dual2;
681    use std::ops::{Add, Div, Mul, Sub};
682    use std::time::Instant;
683
684    #[test]
685    fn new() {
686        let x = Dual::new(1.0, vec!["a".to_string(), "a".to_string()]);
687        assert_eq!(x.real, 1.0_f64);
688        assert_eq!(
689            *x.vars,
690            IndexSet::<String>::from_iter(vec!["a".to_string()])
691        );
692        assert_eq!(x.dual, Array1::from_vec(vec![1.0_f64]));
693    }
694
695    #[test]
696    fn new_with_dual() {
697        let x = Dual::try_new(1.0, vec!["a".to_string(), "a".to_string()], vec![2.5]).unwrap();
698        assert_eq!(x.real, 1.0_f64);
699        assert_eq!(
700            *x.vars,
701            IndexSet::<String>::from_iter(vec!["a".to_string()])
702        );
703        assert_eq!(x.dual, Array1::from_vec(vec![2.5_f64]));
704    }
705
706    #[test]
707    fn new_len_mismatch() {
708        let result =
709            Dual::try_new(1.0, vec!["a".to_string(), "a".to_string()], vec![1.0, 2.0]).is_err();
710        assert!(result);
711    }
712
713    #[test]
714    fn ptr_eq() {
715        let x = Dual::new(1.0, vec!["a".to_string()]);
716        let y = Dual::new(1.0, vec!["a".to_string()]);
717        assert!(x.ptr_eq(&y) == false);
718    }
719
720    #[test]
721    fn to_new_vars() {
722        let x = Dual::try_new(1.5, vec!["a".to_string(), "b".to_string()], vec![1., 2.]).unwrap();
723        let y = Dual::try_new(2.0, vec!["a".to_string(), "c".to_string()], vec![3., 3.]).unwrap();
724        let z = x.to_new_vars(&y.vars, None);
725        assert_eq!(z.real, 1.5_f64);
726        assert!(y.ptr_eq(&z));
727        assert_eq!(z.dual, Array1::from_vec(vec![1.0, 0.0]));
728        let u = x.to_new_vars(x.vars(), None);
729        assert!(u.ptr_eq(&x))
730    }
731
732    #[test]
733    fn new_from() {
734        let x = Dual::try_new(2.0, vec!["a".to_string(), "b".to_string()], vec![3., 3.]).unwrap();
735        let y = Dual::try_new_from(
736            &x,
737            2.0,
738            vec!["a".to_string(), "c".to_string()],
739            vec![3., 3.],
740        )
741        .unwrap();
742        assert_eq!(y.real, 2.0_f64);
743        assert!(y.ptr_eq(&x));
744        assert_eq!(y.dual, Array1::from_vec(vec![3.0, 0.0]));
745    }
746
747    #[test]
748    fn vars() {
749        let x = Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap();
750        let y = Dual::new(1.5, vec!["y".to_string()]).to_new_vars(x.vars(), None);
751        assert!(x.ptr_eq(&y));
752        assert_eq!(y.dual, Array1::from_vec(vec![0.0, 1.0]));
753    }
754
755    #[test]
756    fn vars_cmp() {
757        let x = Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.0, 0.0]).unwrap();
758        let y = Dual::new(1.5, vec!["y".to_string()]);
759        let y2 = Dual::new(1.5, vec!["y".to_string()]);
760        let z = x.to_new_vars(y.vars(), None);
761        let u = Dual::new(1.5, vec!["u".to_string()]);
762        assert_eq!(x.vars_cmp(y.vars()), VarsRelationship::Superset);
763        assert_eq!(y.vars_cmp(z.vars()), VarsRelationship::ArcEquivalent);
764        assert_eq!(y.vars_cmp(y2.vars()), VarsRelationship::ValueEquivalent);
765        assert_eq!(y.vars_cmp(x.vars()), VarsRelationship::Subset);
766        assert_eq!(y.vars_cmp(u.vars()), VarsRelationship::Difference);
767    }
768
769    #[test]
770    fn default() {
771        let x = Dual::default();
772        assert_eq!(x.real, 0.0_f64);
773        assert_eq!(x.vars.len(), 0_usize);
774        assert_eq!(x.dual, Array1::<f64>::from_vec(vec![]));
775    }
776
777    // OPS TESTS
778
779    #[test]
780    fn unitialised_derivs_eq_1() {
781        let d = Dual::new(2.3, Vec::from([String::from("a"), String::from("b")]));
782        for (_, val) in d.dual.indexed_iter() {
783            assert!(*val == 1.0)
784        }
785    }
786
787    #[test]
788    fn gradient1_no_equiv() {
789        let d1 =
790            Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.1, 2.2]).unwrap();
791        let result = d1.gradient1(vec!["y".to_string(), "z".to_string(), "x".to_string()]);
792        let expected = Array1::from_vec(vec![2.2, 0.0, 1.1]);
793        assert_eq!(result, expected)
794    }
795
796    #[test]
797    fn gradient1_equiv() {
798        let d1 =
799            Dual::try_new(2.5, vec!["x".to_string(), "y".to_string()], vec![1.1, 2.2]).unwrap();
800        let result = d1.gradient1(vec!["x".to_string(), "y".to_string()]);
801        let expected = Array1::from_vec(vec![1.1, 2.2]);
802        assert_eq!(result, expected)
803    }
804
805    // PROFILING
806
807    #[test]
808    fn vars_cmp_profile() {
809        // Setup
810        let vars = 500_usize;
811        let x = Dual::try_new(
812            1.5,
813            (1..=vars).map(|x| x.to_string()).collect(),
814            (1..=vars).map(|x| x as f64).collect(),
815        )
816        .unwrap();
817        let y = Dual::try_new(
818            1.5,
819            (1..=vars).map(|x| x.to_string()).collect(),
820            (1..=vars).map(|x| x as f64).collect(),
821        )
822        .unwrap();
823        let z = Dual::new_from(&x, 1.0, Vec::new());
824        let u = Dual::try_new(
825            1.5,
826            (1..vars).map(|x| x.to_string()).collect(),
827            (1..vars).map(|x| x as f64).collect(),
828        )
829        .unwrap();
830        let s = Dual::try_new(
831            1.5,
832            (0..(vars - 1)).map(|x| x.to_string()).collect(), // 2..Vars+1 13us  0..Vars-1  48ns
833            (1..vars).map(|x| x as f64).collect(),
834        )
835        .unwrap();
836
837        println!("\nProfiling vars_cmp (VarsRelationship::ArcEquivalent):");
838        let now = Instant::now();
839        // Code block to measure.
840        {
841            for _i in 0..100000 {
842                // Arc::ptr_eq(&x.vars, &y.vars);
843                x.vars_cmp(&z.vars);
844            }
845        }
846        let elapsed = now.elapsed();
847        println!("\nElapsed: {:.2?}", elapsed / 100000);
848
849        println!("\nProfiling vars_cmp (VarsRelationship::ValueEquivalent):");
850        let now = Instant::now();
851        // Code block to measure.
852        {
853            for _i in 0..1000 {
854                // Arc::ptr_eq(&x.vars, &y.vars);
855                x.vars_cmp(&y.vars);
856            }
857        }
858        let elapsed = now.elapsed();
859        println!("\nElapsed: {:.2?}", elapsed / 1000);
860
861        println!("\nProfiling vars_cmp (VarsRelationship::Superset):");
862        let now = Instant::now();
863        // Code block to measure.
864        {
865            for _i in 0..1000 {
866                // Arc::ptr_eq(&x.vars, &y.vars);
867                x.vars_cmp(&u.vars);
868            }
869        }
870        let elapsed = now.elapsed();
871        println!("\nElapsed: {:.2?}", elapsed / 1000);
872
873        println!("\nProfiling vars_cmp (VarsRelationship::Different):");
874        let now = Instant::now();
875        // Code block to measure.
876        {
877            for _i in 0..1000 {
878                // Arc::ptr_eq(&x.vars, &y.vars);
879                x.vars_cmp(&s.vars);
880            }
881        }
882        let elapsed = now.elapsed();
883        println!("\nElapsed: {:.2?}", elapsed / 1000);
884    }
885
886    #[test]
887    fn to_union_vars_profile() {
888        // Setup
889        let vars = 500_usize;
890        let x = Dual::try_new(
891            1.5,
892            (1..=vars).map(|x| x.to_string()).collect(),
893            (0..vars).map(|x| x as f64).collect(),
894        )
895        .unwrap();
896        let y = Dual::try_new(
897            1.5,
898            (1..=vars).map(|x| x.to_string()).collect(),
899            (0..vars).map(|x| x as f64).collect(),
900        )
901        .unwrap();
902        let z = Dual::new_from(&x, 1.0, Vec::new());
903        let u = Dual::try_new(
904            1.5,
905            (1..vars).map(|x| x.to_string()).collect(),
906            (1..vars).map(|x| x as f64).collect(),
907        )
908        .unwrap();
909        let s = Dual::try_new(
910            1.5,
911            (0..(vars - 1)).map(|x| x.to_string()).collect(),
912            (0..(vars - 1)).map(|x| x as f64).collect(),
913        )
914        .unwrap();
915
916        println!("\nProfiling to_union_vars (VarsRelationship::ArcEquivalent):");
917        let now = Instant::now();
918        // Code block to measure.
919        {
920            for _i in 0..100000 {
921                // Arc::ptr_eq(&x.vars, &y.vars);
922                x.to_union_vars(&z, None);
923            }
924        }
925        let elapsed = now.elapsed();
926        println!("\nElapsed: {:.2?}", elapsed / 100000);
927
928        println!("\nProfiling to_union_vars (VarsRelationship::ValueEquivalent):");
929        let now = Instant::now();
930        // Code block to measure.
931        {
932            for _i in 0..1000 {
933                // Arc::ptr_eq(&x.vars, &y.vars);
934                x.to_union_vars(&y, None);
935            }
936        }
937        let elapsed = now.elapsed();
938        println!("\nElapsed: {:.2?}", elapsed / 1000);
939
940        println!("\nProfiling to_union_vars (VarsRelationship::Superset):");
941        let now = Instant::now();
942        // Code block to measure.
943        {
944            for _i in 0..100 {
945                // Arc::ptr_eq(&x.vars, &y.vars);
946                x.to_union_vars(&u, None);
947            }
948        }
949        let elapsed = now.elapsed();
950        println!("\nElapsed: {:.2?}", elapsed / 100);
951
952        println!("\nProfiling to_union_vars (VarsRelationship::Different):");
953        let now = Instant::now();
954        // Code block to measure.
955        {
956            for _i in 0..100 {
957                // Arc::ptr_eq(&x.vars, &y.vars);
958                x.to_union_vars(&s, None);
959            }
960        }
961        let elapsed = now.elapsed();
962        println!("\nElapsed: {:.2?}", elapsed / 100);
963    }
964
965    #[test]
966    fn std_ops_ref_profile() {
967        fn four_ops<T>(a: &T, b: &T, c: &T, d: &T) -> T
968        where
969            for<'a> &'a T: Add<&'a T, Output = T>
970                + Sub<&'a T, Output = T>
971                + Div<&'a T, Output = T>
972                + Mul<&'a T, Output = T>,
973        {
974            &(&(a + b) * &(c / d)) - a
975        }
976
977        let vars = 500_usize;
978        let a = Dual::try_new(
979            1.5,
980            (1..=vars).map(|x| x.to_string()).collect(),
981            (0..vars).map(|x| x as f64).collect(),
982        )
983        .unwrap();
984        // let b = Dual::new(
985        //     3.5,
986        //     (2..=(VARS+1)).map(|x| x.to_string()).collect(),
987        //     (0..VARS).map(|x| x as f64).collect(),
988        // );
989        // let c = Dual::new(
990        //     5.5,
991        //     (3..=(VARS+2)).map(|x| x.to_string()).collect(),
992        //     (0..VARS).map(|x| x as f64).collect(),
993        // );
994        // let d = Dual::new(
995        //     6.5,
996        //     (4..=(VARS+3)).map(|x| x.to_string()).collect(),
997        //     (0..VARS).map(|x| x as f64).collect(),
998        // );
999        let b = Dual::try_new_from(
1000            &a,
1001            3.5,
1002            (1..=vars).map(|x| x.to_string()).collect(),
1003            (0..vars).map(|x| x as f64).collect(),
1004        )
1005        .unwrap();
1006        let c = Dual::try_new_from(
1007            &a,
1008            5.5,
1009            (1..=vars).map(|x| x.to_string()).collect(),
1010            (0..vars).map(|x| x as f64).collect(),
1011        )
1012        .unwrap();
1013        let d = Dual::try_new_from(
1014            &a,
1015            6.5,
1016            (1..=vars).map(|x| x.to_string()).collect(),
1017            (0..vars).map(|x| x as f64).collect(),
1018        )
1019        .unwrap();
1020
1021        println!("\nProfiling f64 std ops:");
1022        let now = Instant::now();
1023        // Code block to measure.
1024        {
1025            for _i in 0..1000 {
1026                // Arc::ptr_eq(&x.vars, &y.vars);
1027                let _x = four_ops(&a, &b, &c, &d);
1028            }
1029        }
1030        let elapsed = now.elapsed();
1031        println!("\nElapsed: {:.9?}", elapsed / 1000);
1032    }
1033
1034    // copied from old dual2.rs
1035
1036    use ndarray::arr2;
1037
1038    #[test]
1039    fn clone_arc2() {
1040        let d1 = Dual2::new(20.0, vec!["a".to_string()]);
1041        let d2 = d1.clone();
1042        assert!(Arc::ptr_eq(&d1.vars, &d2.vars))
1043    }
1044
1045    #[test]
1046    fn default_dual2() {
1047        let result = Dual2::default();
1048        let expected = Dual2::new(0.0, Vec::new());
1049        assert_eq!(result, expected);
1050    }
1051
1052    #[test]
1053    fn to_new_ordered_vars2() {
1054        let d1 = Dual2::new(20.0, vec!["a".to_string()]);
1055        let d2 = Dual2::new(20.0, vec!["a".to_string(), "b".to_string()]);
1056        let d3 = d1.to_new_vars(&d2.vars, None);
1057        assert!(Arc::ptr_eq(&d3.vars, &d2.vars));
1058        assert!(d3.dual.len() == 2);
1059        let d4 = d2.to_new_vars(&d1.vars, None);
1060        assert!(Arc::ptr_eq(&d4.vars, &d1.vars));
1061        assert!(d4.dual.len() == 1);
1062    }
1063
1064    #[test]
1065    fn new_dual2() {
1066        Dual2::new(2.3, Vec::from([String::from("a")]));
1067    }
1068
1069    #[test]
1070    fn new_dual_error2() {
1071        assert!(Dual2::try_new(
1072            2.3,
1073            Vec::from([String::from("a"), String::from("b")]),
1074            Vec::from([1.0]),
1075            Vec::new(),
1076        )
1077        .is_err());
1078    }
1079
1080    #[test]
1081    fn new_dual2_error() {
1082        assert!(Dual2::try_new(
1083            2.3,
1084            Vec::from([String::from("a"), String::from("b")]),
1085            Vec::from([1.0, 2.3]),
1086            Vec::from([1.0, 2.4, 3.4]),
1087        )
1088        .is_err());
1089    }
1090
1091    #[test]
1092    fn try_new_from2() {
1093        let x = Dual2::new(1.2, vec!["x".to_string(), "y".to_string()]);
1094        let y = Dual2::try_new_from(&x, 3.2, vec!["y".to_string()], vec![1.9], vec![2.1]).unwrap();
1095        let z = Dual2::try_new(
1096            3.2,
1097            vec!["x".to_string(), "y".to_string()],
1098            vec![0., 1.9],
1099            vec![0., 0., 0., 2.1],
1100        )
1101        .unwrap();
1102        assert_eq!(y, z);
1103    }
1104
1105    #[test]
1106    fn to_new_vars2() {
1107        let d1 = Dual2::new(2.5, vec!["x".to_string()]);
1108        let d2 = Dual2::new(3.5, vec!["x".to_string()]);
1109        let d3 = d1.to_new_vars(d2.vars(), None);
1110        assert!(d3.ptr_eq(&d2));
1111        assert_eq!(d3.real, 2.5);
1112        assert_eq!(d3.dual, Array1::from_vec(vec![1.0]));
1113    }
1114
1115    #[test]
1116    fn gradient2_equivval() {
1117        let d1 = Dual2::try_new(
1118            2.5,
1119            vec!["x".to_string(), "y".to_string()],
1120            vec![2.3, 4.5],
1121            vec![1.0, 2.5, 2.5, 5.0],
1122        )
1123        .unwrap();
1124        let result = d1.gradient2(vec!["x".to_string(), "y".to_string()]);
1125        let expected = arr2(&[[2., 5.], [5., 10.]]);
1126        assert_eq!(result, expected);
1127    }
1128
1129    #[test]
1130    fn gradient2_diffvars2() {
1131        let d1 = Dual2::try_new(
1132            2.5,
1133            vec!["x".to_string(), "y".to_string()],
1134            vec![2.3, 4.5],
1135            vec![1.0, 2.5, 2.5, 5.0],
1136        )
1137        .unwrap();
1138        let result = d1.gradient2(vec!["z".to_string(), "y".to_string()]);
1139        let expected = arr2(&[[0., 0.], [0., 10.]]);
1140        assert_eq!(result, expected);
1141    }
1142
1143    #[test]
1144    fn uninitialised_derivs_eq_one2() {
1145        let d = Dual2::new(2.3, Vec::from([String::from("a"), String::from("b")]));
1146        for (_, val) in d.dual.indexed_iter() {
1147            assert!(*val == 1.0)
1148        }
1149    }
1150
1151    #[test]
1152    fn ops_equiv2() {
1153        let d1 = Dual2::try_new(1.5, vec!["x".to_string()], vec![1.0], vec![0.0]).unwrap();
1154        let d2 = Dual2::try_new(2.5, vec!["x".to_string()], vec![2.0], vec![0.0]).unwrap();
1155        let result = &d1 + &d2;
1156        assert_eq!(
1157            result,
1158            Dual2::try_new(4.0, vec!["x".to_string()], vec![3.0], vec![0.0]).unwrap()
1159        );
1160        let result = &d1 - &d2;
1161        assert_eq!(
1162            result,
1163            Dual2::try_new(-1.0, vec!["x".to_string()], vec![-1.0], vec![0.0]).unwrap()
1164        );
1165    }
1166
1167    #[test]
1168    fn grad_manifold() {
1169        let d1 = Dual2::try_new(
1170            2.0,
1171            vec!["x".to_string(), "y".to_string(), "z".to_string()],
1172            vec![1., 2., 3.],
1173            vec![2., 3., 4., 3., 5., 6., 4., 6., 7.],
1174        )
1175        .unwrap();
1176        let result = d1.gradient1_manifold(vec!["y".to_string(), "z".to_string()]);
1177        assert_eq!(result[0].real, 2.);
1178        assert_eq!(result[1].real, 3.);
1179        assert_eq!(result[0].dual, Array1::from_vec(vec![10., 12.]));
1180        assert_eq!(result[1].dual, Array1::from_vec(vec![12., 14.]));
1181        assert_eq!(result[0].dual2, Array2::<f64>::zeros((2, 2)));
1182        assert_eq!(result[1].dual2, Array2::<f64>::zeros((2, 2)));
1183    }
1184
1185    // #[test]
1186    // #[should_panic]
1187    // fn no_dual_cross(){
1188    //     let a = Dual::new(2.0, Vec::new(), Vec::new());
1189    //     let b = Dual2::new(3.0, Vec::new(), Vec::new(), Vec::new());
1190    //     a + b
1191    // }
1192}