rateslib/curves/
curve_py.rs

1//! Wrapper module to export Rust curve data types to Python using pyo3 bindings.
2
3use crate::curves::nodes::{Nodes, NodesTimestamp};
4use crate::curves::{
5    CurveDF, CurveInterpolation, FlatBackwardInterpolator, FlatForwardInterpolator,
6    LinearInterpolator, LinearZeroRateInterpolator, LogLinearInterpolator, Modifier,
7    NullInterpolator,
8};
9use crate::dual::{get_variable_tags, set_order, ADOrder, Dual, Dual2, Number};
10use crate::json::json_py::DeserializedObj;
11use crate::json::JSON;
12use crate::scheduling::{Calendar, Convention};
13use bincode::config::legacy;
14use bincode::serde::{decode_from_slice, encode_to_vec};
15use chrono::NaiveDateTime;
16use indexmap::IndexMap;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use pyo3::types::PyBytes;
20use serde::{Deserialize, Serialize};
21
22/// Interpolation
23#[derive(Debug, Clone, PartialEq, FromPyObject, Deserialize, Serialize, IntoPyObject)]
24pub(crate) enum CurveInterpolator {
25    LogLinear(LogLinearInterpolator),
26    Linear(LinearInterpolator),
27    LinearZeroRate(LinearZeroRateInterpolator),
28    FlatForward(FlatForwardInterpolator),
29    FlatBackward(FlatBackwardInterpolator),
30    Null(NullInterpolator),
31}
32
33// // removed upgrading to pyo3 0.23, see https://pyo3.rs/v0.23.0/migration#intopyobject-and-intopyobjectref-derive-macros
34// impl IntoPy<PyObject> for CurveInterpolator {
35//     fn into_py(self, py: Python<'_>) -> PyObject {
36//         macro_rules! into_py {
37//             ($obj: ident) => {
38//                 Py::new(py, $obj).unwrap().to_object(py)
39//             };
40//         }
41//
42//         match self {
43//             CurveInterpolator::LogLinear(i) => into_py!(i),
44//             CurveInterpolator::Linear(i) => into_py!(i),
45//             CurveInterpolator::LinearZeroRate(i) => into_py!(i),
46//             CurveInterpolator::FlatForward(i) => into_py!(i),
47//             CurveInterpolator::FlatBackward(i) => into_py!(i),
48//             CurveInterpolator::Null(i) => into_py!(i),
49//         }
50//     }
51// }
52
53impl CurveInterpolation for CurveInterpolator {
54    fn interpolated_value(&self, nodes: &NodesTimestamp, date: &NaiveDateTime) -> Number {
55        match self {
56            CurveInterpolator::LogLinear(i) => i.interpolated_value(nodes, date),
57            CurveInterpolator::Linear(i) => i.interpolated_value(nodes, date),
58            CurveInterpolator::LinearZeroRate(i) => i.interpolated_value(nodes, date),
59            CurveInterpolator::FlatBackward(i) => i.interpolated_value(nodes, date),
60            CurveInterpolator::FlatForward(i) => i.interpolated_value(nodes, date),
61            CurveInterpolator::Null(i) => i.interpolated_value(nodes, date),
62        }
63    }
64}
65
66#[pyclass(module = "rateslib.rs")]
67#[derive(Clone, Deserialize, Serialize)]
68pub(crate) struct Curve {
69    inner: CurveDF<CurveInterpolator, Calendar>,
70}
71
72#[pymethods]
73impl Curve {
74    #[new]
75    #[pyo3(signature = (nodes, interpolator, ad, id, convention, modifier, calendar, index_base=None))]
76    fn new_py(
77        nodes: IndexMap<NaiveDateTime, Number>,
78        interpolator: CurveInterpolator,
79        ad: ADOrder,
80        id: String,
81        convention: Convention,
82        modifier: Modifier,
83        calendar: Calendar,
84        index_base: Option<f64>,
85    ) -> PyResult<Self> {
86        let nodes_ = nodes_into_order(nodes, ad, &id);
87        let inner = CurveDF::try_new(
88            nodes_,
89            interpolator,
90            &id,
91            convention,
92            modifier,
93            index_base,
94            calendar,
95        )?;
96        Ok(Self { inner })
97    }
98
99    #[getter]
100    fn id(&self) -> String {
101        self.inner.id.clone()
102    }
103
104    #[getter]
105    fn nodes(&self) -> IndexMap<NaiveDateTime, Number> {
106        let nodes = Nodes::from(self.inner.nodes.clone());
107        match nodes {
108            Nodes::F64(i) => IndexMap::from_iter(i.into_iter().map(|(k, v)| (k, Number::F64(v)))),
109            Nodes::Dual(i) => IndexMap::from_iter(i.into_iter().map(|(k, v)| (k, Number::Dual(v)))),
110            Nodes::Dual2(i) => {
111                IndexMap::from_iter(i.into_iter().map(|(k, v)| (k, Number::Dual2(v))))
112            }
113        }
114    }
115
116    #[getter]
117    fn ad(&self) -> ADOrder {
118        self.inner.ad()
119    }
120
121    #[getter]
122    fn interpolation(&self) -> String {
123        match self.inner.interpolator {
124            CurveInterpolator::Linear(_) => "linear".to_string(),
125            CurveInterpolator::LogLinear(_) => "log_linear".to_string(),
126            CurveInterpolator::LinearZeroRate(_) => "linear_zero_rate".to_string(),
127            CurveInterpolator::FlatForward(_) => "flat_forward".to_string(),
128            CurveInterpolator::FlatBackward(_) => "flat_backward".to_string(),
129            CurveInterpolator::Null(_) => "null".to_string(),
130        }
131    }
132
133    #[getter]
134    fn convention(&self) -> Convention {
135        self.inner.convention
136    }
137
138    #[getter]
139    fn modifier(&self) -> Modifier {
140        self.inner.modifier
141    }
142
143    #[pyo3(name = "index_value")]
144    fn index_value_py(&self, date: NaiveDateTime) -> PyResult<Number> {
145        self.inner.index_value(&date)
146    }
147
148    fn set_ad_order(&mut self, ad: ADOrder) -> PyResult<()> {
149        let _ = self.inner.set_ad_order(ad);
150        Ok(())
151    }
152
153    fn __getitem__(&self, date: NaiveDateTime) -> Number {
154        self.inner.interpolated_value(&date)
155    }
156
157    fn __eq__(&self, other: Curve) -> bool {
158        self.inner.eq(&other.inner)
159    }
160
161    // JSON
162    /// Create a JSON string representation of the object.
163    ///
164    /// Returns
165    /// -------
166    /// str
167    #[pyo3(name = "to_json")]
168    fn to_json_py(&self) -> PyResult<String> {
169        match DeserializedObj::Curve(self.clone()).to_json() {
170            Ok(v) => Ok(v),
171            Err(_) => Err(PyValueError::new_err(
172                "Failed to serialize `Curve` to JSON.",
173            )),
174        }
175    }
176
177    // Pickling
178    pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
179        *self = decode_from_slice(state.as_bytes(), legacy()).unwrap().0;
180        Ok(())
181    }
182    pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
183        Ok(PyBytes::new(py, &encode_to_vec(&self, legacy()).unwrap()))
184    }
185    pub fn __getnewargs__(
186        &self,
187    ) -> PyResult<(
188        IndexMap<NaiveDateTime, Number>,
189        CurveInterpolator,
190        ADOrder,
191        String,
192        Convention,
193        Modifier,
194        Calendar,
195        Option<f64>,
196    )> {
197        Ok((
198            self.inner.nodes.index_map(),
199            self.inner.interpolator.clone(),
200            self.inner.ad(),
201            self.inner.id.clone(),
202            self.inner.convention,
203            self.inner.modifier,
204            self.inner.calendar.clone(),
205            self.inner.index_base,
206        ))
207    }
208}
209
210// /// Convert the `nodes`of a `Curve` from a `HashMap` input form into the local data model.
211// /// Will upcast f64 values to a new ADOrder adding curve variable tags by id.
212// fn hashmap_into_nodes_timestamp(
213//     h: HashMap<NaiveDateTime, Number>,
214//     ad: ADOrder,
215//     id: &str,
216// ) -> NodesTimestamp {
217//     let vars: Vec<String> = get_variable_tags(id, h.keys().len());
218//
219//     /// First convert to IndexMap and sort key order.
220//     // let mut im: IndexMap<NaiveDateTime, Number> = IndexMap::from_iter(h.into_iter());
221//     let mut im: IndexMap<i64, Number> = IndexMap::from_iter(h.into_iter().map(|(k,v)| (k.and_utc().timestamp(), v)));
222//     im.sort_keys();
223//
224//     match ad {
225//         ADOrder::Zero => { NodesTimestamp::F64(IndexMap::from_iter(im.into_iter().map(|(k,v)| (k, f64::from(v))))) }
226//         ADOrder::One => { NodesTimestamp::Dual(IndexMap::from_iter(im.into_iter().enumerate().map(|(i,(k,v))| (k, Dual::from(set_order_with_conversion(v, ad, vec![vars[i].clone()])))))) }
227//         ADOrder::Two => { NodesTimestamp::Dual2(IndexMap::from_iter(im.into_iter().enumerate().map(|(i,(k,v))| (k, Dual2::from(set_order_with_conversion(v, ad, vec![vars[i].clone()])))))) }
228//     }
229// }
230
231fn nodes_into_order(mut nodes: IndexMap<NaiveDateTime, Number>, ad: ADOrder, id: &str) -> Nodes {
232    let vars: Vec<String> = get_variable_tags(id, nodes.keys().len());
233    nodes.sort_keys();
234    match ad {
235        ADOrder::Zero => Nodes::F64(IndexMap::from_iter(
236            nodes.into_iter().map(|(k, v)| (k, f64::from(v))),
237        )),
238        ADOrder::One => {
239            Nodes::Dual(IndexMap::from_iter(nodes.into_iter().enumerate().map(
240                |(i, (k, v))| (k, Dual::from(set_order(v, ad, vec![vars[i].clone()]))),
241            )))
242        }
243        ADOrder::Two => {
244            Nodes::Dual2(IndexMap::from_iter(nodes.into_iter().enumerate().map(
245                |(i, (k, v))| (k, Dual2::from(set_order(v, ad, vec![vars[i].clone()]))),
246            )))
247        }
248    }
249}
250
251#[pymethods]
252impl Modifier {
253    // Pickling
254    #[new]
255    fn new_py(ad: u8) -> PyResult<Modifier> {
256        match ad {
257            0_u8 => Ok(Modifier::Act),
258            1_u8 => Ok(Modifier::F),
259            2_u8 => Ok(Modifier::ModF),
260            3_u8 => Ok(Modifier::P),
261            4_u8 => Ok(Modifier::ModP),
262            _ => Err(PyValueError::new_err(
263                "unreachable code on Convention pickle.",
264            )),
265        }
266    }
267    pub fn __getnewargs__<'py>(&self) -> PyResult<(u8,)> {
268        match self {
269            Modifier::Act => Ok((0_u8,)),
270            Modifier::F => Ok((1_u8,)),
271            Modifier::ModF => Ok((2_u8,)),
272            Modifier::P => Ok((3_u8,)),
273            Modifier::ModP => Ok((4_u8,)),
274        }
275    }
276}
277
278#[pyfunction]
279pub(crate) fn _get_modifier_str(modifier: Modifier) -> String {
280    match modifier {
281        Modifier::F => "F".to_string(),
282        Modifier::ModF => "MF".to_string(),
283        Modifier::P => "P".to_string(),
284        Modifier::ModP => "MP".to_string(),
285        Modifier::Act => "NONE".to_string(),
286    }
287}