1use crate::curves::interpolation::utils::index_left;
2use crate::curves::nodes::{Nodes, NodesTimestamp};
3use crate::dual::{get_variable_tags, ADOrder, Dual, Dual2, Number};
4use crate::scheduling::{Convention, DateRoll};
5use chrono::NaiveDateTime;
6use indexmap::IndexMap;
7use pyo3::exceptions::PyValueError;
8use pyo3::{pyclass, PyErr};
9use serde::{Deserialize, Serialize};
10use std::cmp::PartialEq;
11
12#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
14pub struct CurveDF<T: CurveInterpolation, U: DateRoll> {
15 pub(crate) nodes: NodesTimestamp,
16 pub(crate) interpolator: T,
17 pub(crate) id: String,
18 pub(crate) convention: Convention,
19 pub(crate) modifier: Modifier,
20 pub(crate) index_base: Option<f64>,
21 pub(crate) calendar: U,
22}
23
24#[pyclass(module = "rateslib.rs", eq, eq_int)]
26#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
27pub enum Modifier {
28 Act,
30 F,
32 ModF,
34 P,
36 ModP,
38}
39
40pub trait CurveInterpolation {
42 fn interpolated_value(&self, nodes: &NodesTimestamp, date: &NaiveDateTime) -> Number;
44
45 fn node_index(&self, nodes: &NodesTimestamp, date_timestamp: i64) -> usize {
47 index_left(&nodes.keys(), &date_timestamp, None)
49 }
50}
51
52impl<T: CurveInterpolation, U: DateRoll> CurveDF<T, U> {
53 pub fn try_new(
54 nodes: Nodes,
55 interpolator: T,
56 id: &str,
57 convention: Convention,
58 modifier: Modifier,
59 index_base: Option<f64>,
60 calendar: U,
61 ) -> Result<Self, PyErr> {
62 let mut nodes = NodesTimestamp::from(nodes);
63 nodes.sort_keys();
64 Ok(Self {
65 nodes,
66 interpolator,
67 id: id.to_string(),
68 convention,
69 modifier,
70 index_base,
71 calendar,
72 })
73 }
74
75 pub fn ad(&self) -> ADOrder {
77 match self.nodes {
78 NodesTimestamp::F64(_) => ADOrder::Zero,
79 NodesTimestamp::Dual(_) => ADOrder::One,
80 NodesTimestamp::Dual2(_) => ADOrder::Two,
81 }
82 }
83
84 pub fn interpolated_value(&self, date: &NaiveDateTime) -> Number {
85 self.interpolator.interpolated_value(&self.nodes, date)
86 }
87
88 pub fn node_index(&self, date_timestamp: i64) -> usize {
89 self.interpolator.node_index(&self.nodes, date_timestamp)
90 }
91
92 pub fn set_ad_order(&mut self, ad: ADOrder) -> Result<(), PyErr> {
93 let vars: Vec<String> = get_variable_tags(&self.id, self.nodes.keys().len());
94 match (ad, &self.nodes) {
95 (ADOrder::Zero, NodesTimestamp::F64(_))
96 | (ADOrder::One, NodesTimestamp::Dual(_))
97 | (ADOrder::Two, NodesTimestamp::Dual2(_)) => {
98 Ok(())
100 }
101 (ADOrder::One, NodesTimestamp::F64(i)) => {
102 self.nodes = NodesTimestamp::Dual(IndexMap::from_iter(
104 i.into_iter()
105 .enumerate()
106 .map(|(i, (k, v))| (*k, Dual::new(*v, vec![vars[i].clone()]))),
107 ));
108 Ok(())
109 }
110 (ADOrder::Two, NodesTimestamp::F64(i)) => {
111 self.nodes = NodesTimestamp::Dual2(IndexMap::from_iter(
113 i.into_iter()
114 .enumerate()
115 .map(|(i, (k, v))| (*k, Dual2::new(*v, vec![vars[i].clone()]))),
116 ));
117 Ok(())
118 }
119 (ADOrder::One, NodesTimestamp::Dual2(i)) => {
120 self.nodes = NodesTimestamp::Dual(IndexMap::from_iter(
121 i.into_iter().map(|(k, v)| (*k, Dual::from(v))),
122 ));
123 Ok(())
124 }
125 (ADOrder::Zero, NodesTimestamp::Dual(i)) => {
126 self.nodes = NodesTimestamp::F64(IndexMap::from_iter(
128 i.into_iter().map(|(k, v)| (*k, f64::from(v))),
129 ));
130 Ok(())
131 }
132 (ADOrder::Zero, NodesTimestamp::Dual2(i)) => {
133 self.nodes = NodesTimestamp::F64(IndexMap::from_iter(
135 i.into_iter().map(|(k, v)| (*k, f64::from(v))),
136 ));
137 Ok(())
138 }
139 (ADOrder::Two, NodesTimestamp::Dual(i)) => {
140 self.nodes = NodesTimestamp::Dual2(IndexMap::from_iter(
142 i.into_iter().map(|(k, v)| (*k, Dual2::from(v))),
143 ));
144 Ok(())
145 }
146 }
147 }
148
149 pub fn index_value(&self, date: &NaiveDateTime) -> Result<Number, PyErr> {
150 match self.index_base {
151 None => Err(PyValueError::new_err("Can only calculate `index_value` for a Curve which has been initialised with `index_base`.")),
152 Some(ib) => {
153 if date.and_utc().timestamp() < self.nodes.first_key() {
154 Ok(Number::F64(0.0))
155 } else {
156 Ok(Number::F64(ib) / self.interpolated_value(date))
157 }
158 }
159 }
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::curves::LogLinearInterpolator;
    use crate::scheduling::{ndt, NamedCal};
    use indexmap::IndexMap;

    /// Three annual `f64` nodes shared by the fixtures below.
    fn f64_nodes() -> Nodes {
        Nodes::F64(IndexMap::from_iter(vec![
            (ndt(2000, 1, 1), 1.0_f64),
            (ndt(2001, 1, 1), 0.99_f64),
            (ndt(2002, 1, 1), 0.98_f64),
        ]))
    }

    /// Build a log-linear curve with id "crv", Act360, ModF and the "all" calendar.
    fn build_curve(
        nodes: Nodes,
        index_base: Option<f64>,
    ) -> CurveDF<LogLinearInterpolator, NamedCal> {
        let interpolator = LogLinearInterpolator::new();
        let cal = NamedCal::try_new("all").unwrap();
        CurveDF::try_new(
            nodes,
            interpolator,
            "crv",
            Convention::Act360,
            Modifier::ModF,
            index_base,
            cal,
        )
        .unwrap()
    }

    fn curve_fixture() -> CurveDF<LogLinearInterpolator, NamedCal> {
        build_curve(f64_nodes(), None)
    }

    fn index_curve_fixture() -> CurveDF<LogLinearInterpolator, NamedCal> {
        build_curve(f64_nodes(), Some(100.0))
    }

    fn curve_dual_fixture() -> CurveDF<LogLinearInterpolator, NamedCal> {
        let nodes = Nodes::Dual(IndexMap::from_iter(vec![
            (ndt(2000, 1, 1), Dual::new(1.0, vec!["x".to_string()])),
            (ndt(2001, 1, 1), Dual::new(0.99, vec!["y".to_string()])),
            (ndt(2002, 1, 1), Dual::new(0.98, vec!["z".to_string()])),
        ]));
        build_curve(nodes, None)
    }

    fn nodes_timestamp_fixture() -> NodesTimestamp {
        NodesTimestamp::from(f64_nodes())
    }

    #[test]
    fn test_get_index() {
        let c = curve_fixture();
        let result = c.node_index(ndt(2001, 7, 30).and_utc().timestamp());
        assert_eq!(result, 1_usize)
    }

    #[test]
    fn test_get_value() {
        let c = curve_fixture();
        let result = c.interpolated_value(&ndt(2000, 7, 1));
        assert_eq!(result, Number::F64(0.9950147597711371))
    }

    #[test]
    fn test_log_linear() {
        let nts = nodes_timestamp_fixture();
        let ll = LogLinearInterpolator::new();
        let result = ll.interpolated_value(&nts, &ndt(2000, 7, 1));
        assert_eq!(result, Number::F64(0.9950147597711371));
    }

    #[test]
    fn test_set_order() {
        let mut curve = curve_fixture();
        curve
            .set_ad_order(ADOrder::One)
            .expect("set_ad_order should succeed");
        let result = curve.interpolated_value(&ndt(2001, 1, 1));
        assert_eq!(
            result,
            Number::Dual(Dual::new(0.99, vec!["crv1".to_string()]))
        );
    }

    #[test]
    fn test_set_order_no_change() {
        let mut curve = curve_dual_fixture();
        curve
            .set_ad_order(ADOrder::One)
            .expect("set_ad_order should succeed");
        let result = curve.interpolated_value(&ndt(2001, 1, 1));
        assert_eq!(result, Number::Dual(Dual::new(0.99, vec!["y".to_string()])));
    }

    #[test]
    fn test_set_order_vars_remain() {
        let mut curve = curve_dual_fixture();
        curve
            .set_ad_order(ADOrder::Two)
            .expect("set_ad_order should succeed");
        let result = curve.interpolated_value(&ndt(2001, 1, 1));
        assert_eq!(
            result,
            Number::Dual2(Dual2::new(0.99, vec!["y".to_string()]))
        );
    }

    #[test]
    fn test_ad_reflects_node_type() {
        let mut curve = curve_fixture();
        assert!(matches!(curve.ad(), ADOrder::Zero));
        curve
            .set_ad_order(ADOrder::Two)
            .expect("set_ad_order should succeed");
        assert!(matches!(curve.ad(), ADOrder::Two));
        curve
            .set_ad_order(ADOrder::Zero)
            .expect("set_ad_order should succeed");
        assert!(matches!(curve.ad(), ADOrder::Zero));
    }

    #[test]
    fn test_index_value() {
        let index_curve = index_curve_fixture();
        let result = index_curve.index_value(&ndt(2001, 1, 1)).unwrap();
        assert_eq!(result, Number::F64(100.0 / 0.99))
    }

    #[test]
    fn test_index_value_prior_to_first() {
        let index_curve = index_curve_fixture();
        let result = index_curve.index_value(&ndt(1980, 1, 1)).unwrap();
        assert_eq!(result, Number::F64(0.0))
    }

    #[test]
    fn test_index_value_requires_index_base() {
        let curve = curve_fixture();
        assert!(curve.index_value(&ndt(2001, 1, 1)).is_err());
    }
}