rateslib/curves/interpolation/
intp_flat_forward.rs

1use crate::curves::nodes::NodesTimestamp;
2use crate::curves::CurveInterpolation;
3use crate::dual::Number;
4use bincode::config::legacy;
5use bincode::serde::{decode_from_slice, encode_to_vec};
6use chrono::NaiveDateTime;
7use pyo3::prelude::*;
8use pyo3::types::{PyBytes, PyTuple};
9use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
10use serde::{Deserialize, Serialize};
11use std::cmp::PartialEq;
12
/// Define flat forward interpolation of nodes.
///
/// The struct has no fields, so it carries no configuration or state of its own; all
/// behaviour comes from its `CurveInterpolation` implementation, which returns one of the
/// two node values bounding the requested date without blending between them.
///
/// It is exposed to Python as a class of the `rateslib.rs` module and derives the serde
/// traits used by its pickling methods.
#[pyclass(module = "rateslib.rs")]
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct FlatForwardInterpolator {}
17
18#[pymethods]
19impl FlatForwardInterpolator {
20    #[new]
21    pub fn new() -> Self {
22        FlatForwardInterpolator {}
23    }
24
25    // Pickling
26    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
27        *self = decode_from_slice(state.as_bytes(), legacy()).unwrap().0;
28        Ok(())
29    }
30    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
31        Ok(PyBytes::new(py, &encode_to_vec(&self, legacy()).unwrap()))
32    }
33    pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
34        Ok(PyTuple::empty(py))
35    }
36}
37
38impl CurveInterpolation for FlatForwardInterpolator {
39    fn interpolated_value(&self, nodes: &NodesTimestamp, date: &NaiveDateTime) -> Number {
40        let x = date.and_utc().timestamp();
41        let index = self.node_index(nodes, x);
42        macro_rules! interp {
43            ($Variant: ident, $indexmap: expr) => {{
44                let (_x1, y1) = $indexmap.get_index(index).unwrap();
45                let (x2, y2) = $indexmap.get_index(index + 1_usize).unwrap();
46                if x >= *x2 {
47                    Number::$Variant(y2.clone())
48                } else {
49                    Number::$Variant(y1.clone())
50                }
51            }};
52        }
53        match nodes {
54            NodesTimestamp::F64(m) => interp!(F64, m),
55            NodesTimestamp::Dual(m) => interp!(Dual, m),
56            NodesTimestamp::Dual2(m) => interp!(Dual2, m),
57        }
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::curves::nodes::Nodes;
    use crate::scheduling::ndt;
    use indexmap::IndexMap;

    /// Three yearly nodes (2000, 2001, 2002) with values 1.0, 0.99 and 0.98.
    fn nodes_timestamp_fixture() -> NodesTimestamp {
        let pillars = IndexMap::from_iter(vec![
            (ndt(2000, 1, 1), 1.0_f64),
            (ndt(2001, 1, 1), 0.99_f64),
            (ndt(2002, 1, 1), 0.98_f64),
        ]);
        NodesTimestamp::from(Nodes::F64(pillars))
    }

    // A date strictly between the first two nodes takes the first node's value.
    #[test]
    fn test_flat_forward() {
        let interpolator = FlatForwardInterpolator::new();
        let value = interpolator.interpolated_value(&nodes_timestamp_fixture(), &ndt(2000, 7, 1));
        assert_eq!(value, Number::F64(1.0));
    }

    // A date before the first node takes the first node's value.
    #[test]
    fn test_flat_forward_left_out_of_bounds() {
        let interpolator = FlatForwardInterpolator::new();
        let value = interpolator.interpolated_value(&nodes_timestamp_fixture(), &ndt(1999, 7, 1));
        assert_eq!(value, Number::F64(1.0));
    }

    // A date after the last node takes the last node's value.
    #[test]
    fn test_flat_forward_right_out_of_bounds() {
        let interpolator = FlatForwardInterpolator::new();
        let value = interpolator.interpolated_value(&nodes_timestamp_fixture(), &ndt(2005, 7, 1));
        assert_eq!(value, Number::F64(0.98));
    }

    // A date exactly on an interior node takes that node's value.
    #[test]
    fn test_flat_forward_equals_interval_value() {
        let interpolator = FlatForwardInterpolator::new();
        let value = interpolator.interpolated_value(&nodes_timestamp_fixture(), &ndt(2001, 1, 1));
        assert_eq!(value, Number::F64(0.99));
    }
}