1use crate::curves::nodes::{Nodes, NodesTimestamp};
4use crate::curves::{
5 CurveDF, CurveInterpolation, FlatBackwardInterpolator, FlatForwardInterpolator,
6 LinearInterpolator, LinearZeroRateInterpolator, LogLinearInterpolator, Modifier,
7 NullInterpolator,
8};
9use crate::dual::{get_variable_tags, set_order, ADOrder, Dual, Dual2, Number};
10use crate::json::json_py::DeserializedObj;
11use crate::json::JSON;
12use crate::scheduling::{Calendar, Convention};
13use bincode::config::legacy;
14use bincode::serde::{decode_from_slice, encode_to_vec};
15use chrono::NaiveDateTime;
16use indexmap::IndexMap;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use pyo3::types::PyBytes;
20use serde::{Deserialize, Serialize};
21
22#[derive(Debug, Clone, PartialEq, FromPyObject, Deserialize, Serialize, IntoPyObject)]
24pub(crate) enum CurveInterpolator {
25 LogLinear(LogLinearInterpolator),
26 Linear(LinearInterpolator),
27 LinearZeroRate(LinearZeroRateInterpolator),
28 FlatForward(FlatForwardInterpolator),
29 FlatBackward(FlatBackwardInterpolator),
30 Null(NullInterpolator),
31}
32
33impl CurveInterpolation for CurveInterpolator {
54 fn interpolated_value(&self, nodes: &NodesTimestamp, date: &NaiveDateTime) -> Number {
55 match self {
56 CurveInterpolator::LogLinear(i) => i.interpolated_value(nodes, date),
57 CurveInterpolator::Linear(i) => i.interpolated_value(nodes, date),
58 CurveInterpolator::LinearZeroRate(i) => i.interpolated_value(nodes, date),
59 CurveInterpolator::FlatBackward(i) => i.interpolated_value(nodes, date),
60 CurveInterpolator::FlatForward(i) => i.interpolated_value(nodes, date),
61 CurveInterpolator::Null(i) => i.interpolated_value(nodes, date),
62 }
63 }
64}
65
66#[pyclass(module = "rateslib.rs")]
67#[derive(Clone, Deserialize, Serialize)]
68pub(crate) struct Curve {
69 inner: CurveDF<CurveInterpolator, Calendar>,
70}
71
72#[pymethods]
73impl Curve {
74 #[new]
75 #[pyo3(signature = (nodes, interpolator, ad, id, convention, modifier, calendar, index_base=None))]
76 fn new_py(
77 nodes: IndexMap<NaiveDateTime, Number>,
78 interpolator: CurveInterpolator,
79 ad: ADOrder,
80 id: String,
81 convention: Convention,
82 modifier: Modifier,
83 calendar: Calendar,
84 index_base: Option<f64>,
85 ) -> PyResult<Self> {
86 let nodes_ = nodes_into_order(nodes, ad, &id);
87 let inner = CurveDF::try_new(
88 nodes_,
89 interpolator,
90 &id,
91 convention,
92 modifier,
93 index_base,
94 calendar,
95 )?;
96 Ok(Self { inner })
97 }
98
99 #[getter]
100 fn id(&self) -> String {
101 self.inner.id.clone()
102 }
103
104 #[getter]
105 fn nodes(&self) -> IndexMap<NaiveDateTime, Number> {
106 let nodes = Nodes::from(self.inner.nodes.clone());
107 match nodes {
108 Nodes::F64(i) => IndexMap::from_iter(i.into_iter().map(|(k, v)| (k, Number::F64(v)))),
109 Nodes::Dual(i) => IndexMap::from_iter(i.into_iter().map(|(k, v)| (k, Number::Dual(v)))),
110 Nodes::Dual2(i) => {
111 IndexMap::from_iter(i.into_iter().map(|(k, v)| (k, Number::Dual2(v))))
112 }
113 }
114 }
115
116 #[getter]
117 fn ad(&self) -> ADOrder {
118 self.inner.ad()
119 }
120
121 #[getter]
122 fn interpolation(&self) -> String {
123 match self.inner.interpolator {
124 CurveInterpolator::Linear(_) => "linear".to_string(),
125 CurveInterpolator::LogLinear(_) => "log_linear".to_string(),
126 CurveInterpolator::LinearZeroRate(_) => "linear_zero_rate".to_string(),
127 CurveInterpolator::FlatForward(_) => "flat_forward".to_string(),
128 CurveInterpolator::FlatBackward(_) => "flat_backward".to_string(),
129 CurveInterpolator::Null(_) => "null".to_string(),
130 }
131 }
132
133 #[getter]
134 fn convention(&self) -> Convention {
135 self.inner.convention
136 }
137
138 #[getter]
139 fn modifier(&self) -> Modifier {
140 self.inner.modifier
141 }
142
143 #[pyo3(name = "index_value")]
144 fn index_value_py(&self, date: NaiveDateTime) -> PyResult<Number> {
145 self.inner.index_value(&date)
146 }
147
148 fn set_ad_order(&mut self, ad: ADOrder) -> PyResult<()> {
149 let _ = self.inner.set_ad_order(ad);
150 Ok(())
151 }
152
153 fn __getitem__(&self, date: NaiveDateTime) -> Number {
154 self.inner.interpolated_value(&date)
155 }
156
157 fn __eq__(&self, other: Curve) -> bool {
158 self.inner.eq(&other.inner)
159 }
160
161 #[pyo3(name = "to_json")]
168 fn to_json_py(&self) -> PyResult<String> {
169 match DeserializedObj::Curve(self.clone()).to_json() {
170 Ok(v) => Ok(v),
171 Err(_) => Err(PyValueError::new_err(
172 "Failed to serialize `Curve` to JSON.",
173 )),
174 }
175 }
176
177 pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
179 *self = decode_from_slice(state.as_bytes(), legacy()).unwrap().0;
180 Ok(())
181 }
182 pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
183 Ok(PyBytes::new(py, &encode_to_vec(&self, legacy()).unwrap()))
184 }
185 pub fn __getnewargs__(
186 &self,
187 ) -> PyResult<(
188 IndexMap<NaiveDateTime, Number>,
189 CurveInterpolator,
190 ADOrder,
191 String,
192 Convention,
193 Modifier,
194 Calendar,
195 Option<f64>,
196 )> {
197 Ok((
198 self.inner.nodes.index_map(),
199 self.inner.interpolator.clone(),
200 self.inner.ad(),
201 self.inner.id.clone(),
202 self.inner.convention,
203 self.inner.modifier,
204 self.inner.calendar.clone(),
205 self.inner.index_base,
206 ))
207 }
208}
209
210fn nodes_into_order(mut nodes: IndexMap<NaiveDateTime, Number>, ad: ADOrder, id: &str) -> Nodes {
232 let vars: Vec<String> = get_variable_tags(id, nodes.keys().len());
233 nodes.sort_keys();
234 match ad {
235 ADOrder::Zero => Nodes::F64(IndexMap::from_iter(
236 nodes.into_iter().map(|(k, v)| (k, f64::from(v))),
237 )),
238 ADOrder::One => {
239 Nodes::Dual(IndexMap::from_iter(nodes.into_iter().enumerate().map(
240 |(i, (k, v))| (k, Dual::from(set_order(v, ad, vec![vars[i].clone()]))),
241 )))
242 }
243 ADOrder::Two => {
244 Nodes::Dual2(IndexMap::from_iter(nodes.into_iter().enumerate().map(
245 |(i, (k, v))| (k, Dual2::from(set_order(v, ad, vec![vars[i].clone()]))),
246 )))
247 }
248 }
249}
250
251#[pymethods]
252impl Modifier {
253 #[new]
255 fn new_py(ad: u8) -> PyResult<Modifier> {
256 match ad {
257 0_u8 => Ok(Modifier::Act),
258 1_u8 => Ok(Modifier::F),
259 2_u8 => Ok(Modifier::ModF),
260 3_u8 => Ok(Modifier::P),
261 4_u8 => Ok(Modifier::ModP),
262 _ => Err(PyValueError::new_err(
263 "unreachable code on Convention pickle.",
264 )),
265 }
266 }
267 pub fn __getnewargs__<'py>(&self) -> PyResult<(u8,)> {
268 match self {
269 Modifier::Act => Ok((0_u8,)),
270 Modifier::F => Ok((1_u8,)),
271 Modifier::ModF => Ok((2_u8,)),
272 Modifier::P => Ok((3_u8,)),
273 Modifier::ModP => Ok((4_u8,)),
274 }
275 }
276}
277
278#[pyfunction]
279pub(crate) fn _get_modifier_str(modifier: Modifier) -> String {
280 match modifier {
281 Modifier::F => "F".to_string(),
282 Modifier::ModF => "MF".to_string(),
283 Modifier::P => "P".to_string(),
284 Modifier::ModP => "MP".to_string(),
285 Modifier::Act => "NONE".to_string(),
286 }
287}