1use chrono::prelude::*;
2use chrono::Days;
3use indexmap::IndexSet;
4use pyo3::exceptions::PyValueError;
5use pyo3::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::cmp::{Eq, PartialEq};
8
9use crate::scheduling::{Adjuster, Adjustment, Calendar, Imm};
10
/// A roll day convention: the rule used to pick the unadjusted date within a given month.
///
/// See `RollDay::try_from_ym` for how each variant maps a (year, month) pair to a date.
#[pyclass(module = "rateslib.rs", eq)]
#[derive(Debug, Copy, Hash, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub enum RollDay {
    /// Roll on this day of the month; in months with fewer days the date is clamped down to
    /// the month's last day.
    Day(u32),
    /// Roll on the date given by `Imm::Wed3` for each month (presumably the third Wednesday —
    /// NOTE(review): confirm against `Imm`).
    IMM(),
}
20
21impl RollDay {
22 pub fn vec_from(udates: &Vec<NaiveDateTime>) -> Vec<Self> {
50 let mut set: IndexSet<RollDay> = IndexSet::new();
51
52 for udate in udates {
53 let mut v: Vec<Self> = vec![RollDay::Day(udate.day())];
55 if Imm::Eom.validate(udate) {
57 let mut day = udate.day() + 1;
58 while day < 32 {
59 v.push(RollDay::Day(day));
60 day = day + 1;
61 }
62 }
63 if Imm::Wed3.validate(udate) {
65 v.push(RollDay::IMM())
66 }
67 set.append(&mut IndexSet::<RollDay>::from_iter(v));
69 }
70 set.into_iter().collect()
71 }
72
73 pub fn try_udate(&self, udate: &NaiveDateTime) -> Result<NaiveDateTime, PyErr> {
85 let msg = "`udate` does not align with given `RollDay`.".to_string();
86 match self {
87 RollDay::Day(31) => {
88 if Imm::Eom.validate(udate) {
89 Ok(*udate)
90 } else {
91 Err(PyValueError::new_err(msg))
92 }
93 }
94 RollDay::Day(30) => {
95 if (Imm::Eom.validate(udate) && udate.day() < 30) || udate.day() == 30 {
96 Ok(*udate)
97 } else {
98 Err(PyValueError::new_err(msg))
99 }
100 }
101 RollDay::Day(29) => {
102 if (Imm::Eom.validate(udate) && udate.day() < 29) || udate.day() == 29 {
103 Ok(*udate)
104 } else {
105 Err(PyValueError::new_err(msg))
106 }
107 }
108 RollDay::IMM() => {
109 if Imm::Wed3.validate(udate) {
110 Ok(*udate)
111 } else {
112 Err(PyValueError::new_err(msg))
113 }
114 }
115 RollDay::Day(value) => {
116 if udate.day() == *value {
117 Ok(*udate)
118 } else {
119 Err(PyValueError::new_err(msg))
120 }
121 }
122 }
123 }
124
125 pub fn try_uadd(&self, udate: &NaiveDateTime, months: i32) -> Result<NaiveDateTime, PyErr> {
140 let _ = self.try_udate(udate)?;
141 Ok(self.uadd(udate, months))
142 }
143
144 pub fn uadd(&self, udate: &NaiveDateTime, months: i32) -> NaiveDateTime {
153 let mut yr_roll = (months.abs() / 12) * months.signum();
155 let rem_months = months - yr_roll * 12;
156
157 let mut new_month = i32::try_from(udate.month()).unwrap() + rem_months;
159 if new_month <= 0 {
160 yr_roll -= 1;
161 new_month = new_month.rem_euclid(12);
162 } else if new_month >= 13 {
163 yr_roll += 1;
164 new_month = new_month.rem_euclid(12);
165 }
166 if new_month == 0 {
167 new_month = 12;
168 }
169
170 self.try_from_ym(udate.year() + yr_roll, new_month.try_into().unwrap())
172 .unwrap()
173 }
174
175 pub fn try_from_ym(&self, year: i32, month: u32) -> Result<NaiveDateTime, PyErr> {
185 match self {
186 RollDay::Day(value) => Ok(get_roll_by_day(year, month, *value)),
187 RollDay::IMM {} => Imm::Wed3.from_ym_opt(year, month),
188 }
189 }
190}
191
192pub(crate) fn get_unadjusteds(
197 date: &NaiveDateTime,
198 adjuster: &Adjuster,
199 calendar: &Calendar,
200) -> Vec<NaiveDateTime> {
201 let mut udates: Vec<NaiveDateTime> = vec![];
202
203 let mut udate = *date;
205 udates.push(udate);
206
207 udate = *date - Days::new(1);
209 while adjuster.adjust(&udate, calendar) == *date {
210 udates.push(udate);
211 udate = udate - Days::new(1);
212 }
213
214 udate = *date + Days::new(1);
216 while adjuster.adjust(&udate, calendar) == *date {
217 udates.push(udate);
218 udate = udate + Days::new(1);
219 }
220
221 udates
222}
223
224fn get_roll_by_day(year: i32, month: u32, day: u32) -> NaiveDateTime {
226 let d = NaiveDate::from_ymd_opt(year, month, day);
227 match d {
228 Some(date) => NaiveDateTime::new(date, NaiveTime::from_hms_opt(0, 0, 0).unwrap()),
229 None => {
230 if day > 28 {
231 get_roll_by_day(year, month, day - 1)
232 } else {
233 panic!("Unexpected error in `get_roll_by_day`")
234 }
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scheduling::{ndt, Cal};

    /// Business-day calendar used by the `get_unadjusteds` tests.
    fn fixture_bus_cal() -> Calendar {
        Cal::try_from_name("bus").unwrap().into()
    }

    #[test]
    fn test_rollday_equality() {
        assert_eq!(RollDay::IMM(), RollDay::IMM());
        assert_ne!(RollDay::IMM(), RollDay::Day(21));
        assert_eq!(RollDay::Day(20), RollDay::Day(20));
        assert_ne!(RollDay::Day(21), RollDay::Day(9));
    }

    #[test]
    fn test_rollday_try_udate() {
        let options: Vec<(RollDay, NaiveDateTime)> = vec![
            (RollDay::Day(15), ndt(2000, 3, 15)),
            (RollDay::Day(31), ndt(2000, 3, 31)),
            (RollDay::Day(31), ndt(2022, 2, 28)),
            (RollDay::Day(30), ndt(2024, 2, 29)),
            (RollDay::Day(31), ndt(2024, 2, 29)),
        ];
        for (roll, udate) in options {
            assert!(
                roll.try_udate(&udate).is_ok(),
                "{roll:?} should accept {udate}"
            );
        }
    }

    #[test]
    fn test_rollday_try_udate_rejects_misaligned() {
        let options: Vec<(RollDay, NaiveDateTime)> = vec![
            (RollDay::Day(15), ndt(2000, 3, 14)),
            // Not month-end, so a 31 roll cannot land here.
            (RollDay::Day(31), ndt(2000, 3, 30)),
            // Month-end, but the month contains a 30th.
            (RollDay::Day(30), ndt(2000, 3, 31)),
            (RollDay::Day(29), ndt(2000, 3, 30)),
            // 2024 is a leap year, so Feb 28 is not month-end.
            (RollDay::Day(29), ndt(2024, 2, 28)),
        ];
        for (roll, udate) in options {
            assert!(
                roll.try_udate(&udate).is_err(),
                "{roll:?} should reject {udate}"
            );
            assert!(
                roll.try_uadd(&udate, 1).is_err(),
                "{roll:?} try_uadd should reject {udate}"
            );
        }
    }

    #[test]
    fn test_rollday_uadd() {
        let options: Vec<(RollDay, NaiveDateTime, i32, NaiveDateTime)> = vec![
            (RollDay::Day(31), ndt(2024, 1, 31), 1, ndt(2024, 2, 29)),
            (RollDay::Day(31), ndt(2024, 1, 31), -1, ndt(2023, 12, 31)),
            (RollDay::Day(15), ndt(2000, 3, 15), 12, ndt(2001, 3, 15)),
            (RollDay::Day(15), ndt(2000, 3, 15), -15, ndt(1998, 12, 15)),
            (RollDay::Day(30), ndt(2023, 11, 30), 3, ndt(2024, 2, 29)),
        ];
        for (roll, udate, months, expected) in options {
            assert_eq!(
                roll.uadd(&udate, months),
                expected,
                "{roll:?} from {udate} by {months} months"
            );
        }
    }

    #[test]
    fn test_get_unadjusteds() {
        let options: Vec<(NaiveDateTime, Vec<NaiveDateTime>)> = vec![
            (ndt(2000, 2, 29), vec![ndt(2000, 2, 29)]),
            (
                ndt(2025, 11, 28),
                vec![ndt(2025, 11, 28), ndt(2025, 11, 29), ndt(2025, 11, 30)],
            ),
            (
                ndt(2025, 2, 3),
                vec![ndt(2025, 2, 3), ndt(2025, 2, 2), ndt(2025, 2, 1)],
            ),
        ];

        for (date, expected) in options {
            let result = get_unadjusteds(
                &date,
                &Adjuster::ModifiedFollowing {},
                &fixture_bus_cal(),
            );
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_vec_from() {
        let options: Vec<(Vec<NaiveDateTime>, Vec<RollDay>)> = vec![
            (
                vec![ndt(2000, 2, 29)],
                vec![RollDay::Day(29), RollDay::Day(30), RollDay::Day(31)],
            ),
            (vec![ndt(2025, 11, 28)], vec![RollDay::Day(28)]),
            (
                vec![ndt(2025, 3, 19)],
                vec![RollDay::Day(19), RollDay::IMM()],
            ),
            (vec![ndt(2025, 9, 15)], vec![RollDay::Day(15)]),
        ];

        for (udates, expected) in options {
            assert_eq!(RollDay::vec_from(&udates), expected);
        }
    }

    #[test]
    fn test_vec_from_multiple() {
        let options: Vec<(Vec<NaiveDateTime>, Vec<RollDay>)> = vec![
            (
                vec![ndt(2000, 2, 29)],
                vec![RollDay::Day(29), RollDay::Day(30), RollDay::Day(31)],
            ),
            (
                vec![ndt(2025, 11, 28), ndt(2025, 11, 29), ndt(2025, 11, 30)],
                vec![
                    RollDay::Day(28),
                    RollDay::Day(29),
                    RollDay::Day(30),
                    RollDay::Day(31),
                ],
            ),
            (
                vec![ndt(2025, 3, 19)],
                vec![RollDay::Day(19), RollDay::IMM()],
            ),
            (
                vec![ndt(2025, 9, 15), ndt(2025, 9, 14), ndt(2025, 9, 13)],
                vec![RollDay::Day(15), RollDay::Day(14), RollDay::Day(13)],
            ),
        ];

        for (udates, expected) in options {
            assert_eq!(RollDay::vec_from(&udates), expected);
        }
    }
}