rateslib/curves/interpolation/intp_linear.rs

1use crate::curves::interpolation::utils::linear_interp;
2use crate::curves::nodes::NodesTimestamp;
3use crate::curves::CurveInterpolation;
4use crate::dual::Number;
5use bincode::config::legacy;
6use bincode::serde::{decode_from_slice, encode_to_vec};
7use chrono::NaiveDateTime;
8use pyo3::prelude::*;
9use pyo3::types::{PyBytes, PyTuple};
10use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
11use serde::{Deserialize, Serialize};
12use std::cmp::PartialEq;
13
14/// Define linear interpolation of nodes.
15#[pyclass(module = "rateslib.rs")]
16#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
17pub struct LinearInterpolator {}
18
19#[pymethods]
20impl LinearInterpolator {
21    #[new]
22    pub fn new() -> Self {
23        LinearInterpolator {}
24    }
25
26    // Pickling
27    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
28        *self = decode_from_slice(state.as_bytes(), legacy()).unwrap().0;
29        Ok(())
30    }
31    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
32        Ok(PyBytes::new(py, &encode_to_vec(&self, legacy()).unwrap()))
33    }
34    pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
35        Ok(PyTuple::empty(py))
36    }
37}
38
39impl CurveInterpolation for LinearInterpolator {
40    fn interpolated_value(&self, nodes: &NodesTimestamp, date: &NaiveDateTime) -> Number {
41        let x = date.and_utc().timestamp();
42        let index = self.node_index(nodes, x);
43
44        macro_rules! interp {
45            ($Variant: ident, $indexmap: expr) => {{
46                let (x1, y1) = $indexmap.get_index(index).unwrap();
47                let (x2, y2) = $indexmap.get_index(index + 1_usize).unwrap();
48                Number::$Variant(linear_interp(*x1 as f64, y1, *x2 as f64, y2, x as f64))
49            }};
50        }
51        match nodes {
52            NodesTimestamp::F64(m) => interp!(F64, m),
53            NodesTimestamp::Dual(m) => interp!(Dual, m),
54            NodesTimestamp::Dual2(m) => interp!(Dual2, m),
55        }
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::curves::nodes::Nodes;
    use crate::scheduling::ndt;
    use indexmap::IndexMap;

    /// Three annual `F64` nodes: 1.0 at 2000-01-01, 0.99 at 2001-01-01, 0.98 at 2002-01-01.
    fn nodes_timestamp_fixture() -> NodesTimestamp {
        let nodes = Nodes::F64(IndexMap::from_iter(vec![
            (ndt(2000, 1, 1), 1.0_f64),
            (ndt(2001, 1, 1), 0.99_f64),
            (ndt(2002, 1, 1), 0.98_f64),
        ]));
        NodesTimestamp::from(nodes)
    }

    /// Assert `result` is a `Number::F64` within a small absolute tolerance of `expected`.
    ///
    /// A tolerance is used so the checks do not depend on the exact floating-point operation
    /// order inside `linear_interp`.
    fn assert_f64_close(result: Number, expected: f64) {
        match result {
            Number::F64(value) => assert!(
                (value - expected).abs() < 1e-14,
                "expected {expected}, got {value}"
            ),
            other => panic!("expected Number::F64, got {other:?}"),
        }
    }

    #[test]
    fn test_linear() {
        let nts = nodes_timestamp_fixture();
        let li = LinearInterpolator::new();
        let result = li.interpolated_value(&nts, &ndt(2000, 7, 1));
        // 2000 is a leap year: 2000-07-01 is 182 days into a 366-day segment.
        // expected = 1.0 + (182 / 366) * (0.99 - 1.0) = 0.995027
        assert_eq!(result, Number::F64(0.9950273224043715));
    }

    #[test]
    fn test_linear_reproduces_node_values() {
        let nts = nodes_timestamp_fixture();
        let li = LinearInterpolator::new();
        assert_f64_close(li.interpolated_value(&nts, &ndt(2000, 1, 1)), 1.0);
        assert_f64_close(li.interpolated_value(&nts, &ndt(2001, 1, 1)), 0.99);
        assert_f64_close(li.interpolated_value(&nts, &ndt(2002, 1, 1)), 0.98);
    }

    #[test]
    fn test_linear_second_segment() {
        let nts = nodes_timestamp_fixture();
        let li = LinearInterpolator::new();
        // 2001 is not a leap year: 2001-07-02 is 182 days into a 365-day segment.
        let expected = 0.99 + (182.0 / 365.0) * (0.98 - 0.99);
        assert_f64_close(li.interpolated_value(&nts, &ndt(2001, 7, 2)), expected);
    }
}