rateslib/curves/interpolation/
intp_flat_forward.rs1use crate::curves::nodes::NodesTimestamp;
2use crate::curves::CurveInterpolation;
3use crate::dual::Number;
4use bincode::config::legacy;
5use bincode::serde::{decode_from_slice, encode_to_vec};
6use chrono::NaiveDateTime;
7use pyo3::prelude::*;
8use pyo3::types::{PyBytes, PyTuple};
9use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
10use serde::{Deserialize, Serialize};
11use std::cmp::PartialEq;
12
13#[pyclass(module = "rateslib.rs")]
15#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
16pub struct FlatForwardInterpolator {}
17
18#[pymethods]
19impl FlatForwardInterpolator {
20 #[new]
21 pub fn new() -> Self {
22 FlatForwardInterpolator {}
23 }
24
25 pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
27 *self = decode_from_slice(state.as_bytes(), legacy()).unwrap().0;
28 Ok(())
29 }
30 pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
31 Ok(PyBytes::new(py, &encode_to_vec(&self, legacy()).unwrap()))
32 }
33 pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
34 Ok(PyTuple::empty(py))
35 }
36}
37
38impl CurveInterpolation for FlatForwardInterpolator {
39 fn interpolated_value(&self, nodes: &NodesTimestamp, date: &NaiveDateTime) -> Number {
40 let x = date.and_utc().timestamp();
41 let index = self.node_index(nodes, x);
42 macro_rules! interp {
43 ($Variant: ident, $indexmap: expr) => {{
44 let (_x1, y1) = $indexmap.get_index(index).unwrap();
45 let (x2, y2) = $indexmap.get_index(index + 1_usize).unwrap();
46 if x >= *x2 {
47 Number::$Variant(y2.clone())
48 } else {
49 Number::$Variant(y1.clone())
50 }
51 }};
52 }
53 match nodes {
54 NodesTimestamp::F64(m) => interp!(F64, m),
55 NodesTimestamp::Dual(m) => interp!(Dual, m),
56 NodesTimestamp::Dual2(m) => interp!(Dual2, m),
57 }
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::curves::nodes::Nodes;
    use crate::scheduling::ndt;
    use indexmap::IndexMap;

    /// Three `F64` nodes dated 1 Jan 2000, 2001 and 2002, converted with
    /// `NodesTimestamp::from`.
    fn nodes_timestamp_fixture() -> NodesTimestamp {
        let entries = vec![
            (ndt(2000, 1, 1), 1.0_f64),
            (ndt(2001, 1, 1), 0.99_f64),
            (ndt(2002, 1, 1), 0.98_f64),
        ];
        NodesTimestamp::from(Nodes::F64(IndexMap::from_iter(entries)))
    }

    /// Interpolate the fixture nodes at `date` with a fresh interpolator.
    fn fixture_value_at(date: &NaiveDateTime) -> Number {
        let nodes = nodes_timestamp_fixture();
        let interpolator = FlatForwardInterpolator::new();
        interpolator.interpolated_value(&nodes, date)
    }

    // A date inside the first interval takes the left node's value.
    #[test]
    fn test_flat_forward() {
        let value = fixture_value_at(&ndt(2000, 7, 1));
        assert_eq!(value, Number::F64(1.0));
    }

    // A date before the first node takes the first node's value.
    #[test]
    fn test_flat_forward_left_out_of_bounds() {
        let value = fixture_value_at(&ndt(1999, 7, 1));
        assert_eq!(value, Number::F64(1.0));
    }

    // A date after the last node takes the last node's value.
    #[test]
    fn test_flat_forward_right_out_of_bounds() {
        let value = fixture_value_at(&ndt(2005, 7, 1));
        assert_eq!(value, Number::F64(0.98));
    }

    // A date exactly on an interior node takes that node's value.
    #[test]
    fn test_flat_forward_equals_interval_value() {
        let value = fixture_value_at(&ndt(2001, 1, 1));
        assert_eq!(value, Number::F64(0.99));
    }
}