1use chrono::prelude::*;
2use chrono::Days;
3use indexmap::IndexSet;
4use pyo3::exceptions::PyValueError;
5use pyo3::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::cmp::{Eq, PartialEq};
8
9use crate::scheduling::{Adjuster, Adjustment, Calendar, Imm};
10
/// The roll-day convention used to generate and validate unadjusted schedule dates.
///
/// - `Day(n)`: roll on day `n` of each month.
/// - `IMM()`: roll on the date defined by `Imm::Wed3` (presumably the third Wednesday of the
///   month — NOTE(review): confirm against the `Imm` definition).
#[pyclass(module = "rateslib.rs", eq)]
#[derive(Debug, Copy, Hash, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub enum RollDay {
    /// Roll on a fixed day of the month.
    ///
    /// When generating a date for a month that is too short (e.g. `Day(31)` in April), the
    /// month's last day is used instead. When validating, `Day(29)`, `Day(30)` and `Day(31)`
    /// also accept a month-end date that falls earlier than the requested day.
    Day(u32),
    /// Roll on the date validated/produced by `Imm::Wed3`.
    IMM(),
}
20
21impl RollDay {
22    pub fn vec_from(udates: &Vec<NaiveDateTime>) -> Vec<Self> {
50        let mut set: IndexSet<RollDay> = IndexSet::new();
51
52        for udate in udates {
53            let mut v: Vec<Self> = vec![RollDay::Day(udate.day())];
55            if Imm::Eom.validate(udate) {
57                let mut day = udate.day() + 1;
58                while day < 32 {
59                    v.push(RollDay::Day(day));
60                    day = day + 1;
61                }
62            }
63            if Imm::Wed3.validate(udate) {
65                v.push(RollDay::IMM())
66            }
67            set.append(&mut IndexSet::<RollDay>::from_iter(v));
69        }
70        set.into_iter().collect()
71    }
72
73    pub fn try_udate(&self, udate: &NaiveDateTime) -> Result<NaiveDateTime, PyErr> {
85        let msg = "`udate` does not align with given `RollDay`.".to_string();
86        match self {
87            RollDay::Day(31) => {
88                if Imm::Eom.validate(udate) {
89                    Ok(*udate)
90                } else {
91                    Err(PyValueError::new_err(msg))
92                }
93            }
94            RollDay::Day(30) => {
95                if (Imm::Eom.validate(udate) && udate.day() < 30) || udate.day() == 30 {
96                    Ok(*udate)
97                } else {
98                    Err(PyValueError::new_err(msg))
99                }
100            }
101            RollDay::Day(29) => {
102                if (Imm::Eom.validate(udate) && udate.day() < 29) || udate.day() == 29 {
103                    Ok(*udate)
104                } else {
105                    Err(PyValueError::new_err(msg))
106                }
107            }
108            RollDay::IMM() => {
109                if Imm::Wed3.validate(udate) {
110                    Ok(*udate)
111                } else {
112                    Err(PyValueError::new_err(msg))
113                }
114            }
115            RollDay::Day(value) => {
116                if udate.day() == *value {
117                    Ok(*udate)
118                } else {
119                    Err(PyValueError::new_err(msg))
120                }
121            }
122        }
123    }
124
125    pub fn try_uadd(&self, udate: &NaiveDateTime, months: i32) -> Result<NaiveDateTime, PyErr> {
140        let _ = self.try_udate(udate)?;
141        Ok(self.uadd(udate, months))
142    }
143
144    pub fn uadd(&self, udate: &NaiveDateTime, months: i32) -> NaiveDateTime {
153        let mut yr_roll = (months.abs() / 12) * months.signum();
155        let rem_months = months - yr_roll * 12;
156
157        let mut new_month = i32::try_from(udate.month()).unwrap() + rem_months;
159        if new_month <= 0 {
160            yr_roll -= 1;
161            new_month = new_month.rem_euclid(12);
162        } else if new_month >= 13 {
163            yr_roll += 1;
164            new_month = new_month.rem_euclid(12);
165        }
166        if new_month == 0 {
167            new_month = 12;
168        }
169
170        self.try_from_ym(udate.year() + yr_roll, new_month.try_into().unwrap())
172            .unwrap()
173    }
174
175    pub fn try_from_ym(&self, year: i32, month: u32) -> Result<NaiveDateTime, PyErr> {
185        match self {
186            RollDay::Day(value) => Ok(get_roll_by_day(year, month, *value)),
187            RollDay::IMM {} => Imm::Wed3.from_ym_opt(year, month),
188        }
189    }
190}
191
192pub(crate) fn get_unadjusteds(
197    date: &NaiveDateTime,
198    adjuster: &Adjuster,
199    calendar: &Calendar,
200) -> Vec<NaiveDateTime> {
201    let mut udates: Vec<NaiveDateTime> = vec![];
202
203    let mut udate = *date;
205    udates.push(udate);
206
207    udate = *date - Days::new(1);
209    while adjuster.adjust(&udate, calendar) == *date {
210        udates.push(udate);
211        udate = udate - Days::new(1);
212    }
213
214    udate = *date + Days::new(1);
216    while adjuster.adjust(&udate, calendar) == *date {
217        udates.push(udate);
218        udate = udate + Days::new(1);
219    }
220
221    udates
222}
223
224fn get_roll_by_day(year: i32, month: u32, day: u32) -> NaiveDateTime {
226    let d = NaiveDate::from_ymd_opt(year, month, day);
227    match d {
228        Some(date) => NaiveDateTime::new(date, NaiveTime::from_hms_opt(0, 0, 0).unwrap()),
229        None => {
230            if day > 28 {
231                get_roll_by_day(year, month, day - 1)
232            } else {
233                panic!("Unexpected error in `get_roll_by_day`")
234            }
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scheduling::{ndt, Cal};

    fn fixture_bus_cal() -> Calendar {
        Cal::try_from_name("bus").unwrap().into()
    }

    #[test]
    fn test_rollday_equality() {
        assert_eq!(RollDay::IMM(), RollDay::IMM());
        assert_ne!(RollDay::IMM(), RollDay::Day(21));
        assert_eq!(RollDay::Day(20), RollDay::Day(20));
        assert_ne!(RollDay::Day(21), RollDay::Day(9));
    }

    #[test]
    fn test_rollday_try_udate() {
        let options: Vec<(RollDay, NaiveDateTime)> = vec![
            (RollDay::Day(15), ndt(2000, 3, 15)),
            (RollDay::Day(31), ndt(2000, 3, 31)),
            (RollDay::Day(31), ndt(2022, 2, 28)),
            (RollDay::Day(30), ndt(2024, 2, 29)),
            (RollDay::Day(31), ndt(2024, 2, 29)),
        ];
        for (roll, udate) in options {
            assert!(
                roll.try_udate(&udate).is_ok(),
                "{:?} should accept {}",
                roll,
                udate
            );
        }
    }

    #[test]
    fn test_rollday_try_udate_rejects_misaligned() {
        let options: Vec<(RollDay, NaiveDateTime)> = vec![
            (RollDay::Day(15), ndt(2000, 3, 14)),
            // 28 Feb is not month-end in a leap year.
            (RollDay::Day(31), ndt(2024, 2, 28)),
            (RollDay::Day(29), ndt(2024, 2, 28)),
            // A 31st is month-end but later than a 30th roll.
            (RollDay::Day(30), ndt(2024, 3, 31)),
            (RollDay::IMM(), ndt(2025, 3, 20)),
        ];
        for (roll, udate) in options {
            assert!(
                roll.try_udate(&udate).is_err(),
                "{:?} should reject {}",
                roll,
                udate
            );
        }
    }

    #[test]
    fn test_rollday_uadd() {
        let options: Vec<(RollDay, NaiveDateTime, i32, NaiveDateTime)> = vec![
            (RollDay::Day(31), ndt(2024, 1, 31), 1, ndt(2024, 2, 29)),
            (RollDay::Day(31), ndt(2024, 3, 31), -1, ndt(2024, 2, 29)),
            (RollDay::Day(31), ndt(2024, 2, 29), 12, ndt(2025, 2, 28)),
            (RollDay::Day(30), ndt(2023, 11, 30), 3, ndt(2024, 2, 29)),
            (RollDay::Day(15), ndt(2024, 1, 15), -13, ndt(2022, 12, 15)),
            (RollDay::Day(15), ndt(2024, 12, 15), 0, ndt(2024, 12, 15)),
            (RollDay::IMM(), ndt(2025, 3, 19), 3, ndt(2025, 6, 18)),
        ];
        for (roll, udate, months, expected) in options {
            assert_eq!(roll.uadd(&udate, months), expected);
        }
    }

    #[test]
    fn test_rollday_try_uadd() {
        assert!(RollDay::Day(31).try_uadd(&ndt(2024, 2, 28), 1).is_err());
        assert_eq!(
            RollDay::Day(31).try_uadd(&ndt(2024, 2, 29), 1).unwrap(),
            ndt(2024, 3, 31)
        );
    }

    #[test]
    fn test_get_roll_by_day_clamps_to_month_end() {
        assert_eq!(get_roll_by_day(2023, 2, 31), ndt(2023, 2, 28));
        assert_eq!(get_roll_by_day(2024, 2, 30), ndt(2024, 2, 29));
        assert_eq!(get_roll_by_day(2024, 4, 31), ndt(2024, 4, 30));
        assert_eq!(get_roll_by_day(2024, 1, 15), ndt(2024, 1, 15));
    }

    #[test]
    fn test_get_unadjusteds() {
        let options: Vec<(NaiveDateTime, Vec<NaiveDateTime>)> = vec![
            (ndt(2000, 2, 29), vec![ndt(2000, 2, 29)]),
            (
                ndt(2025, 11, 28),
                vec![ndt(2025, 11, 28), ndt(2025, 11, 29), ndt(2025, 11, 30)],
            ),
            (
                ndt(2025, 2, 3),
                vec![ndt(2025, 2, 3), ndt(2025, 2, 2), ndt(2025, 2, 1)],
            ),
        ];

        let cal = fixture_bus_cal();
        for (date, expected) in options {
            let result = get_unadjusteds(&date, &Adjuster::ModifiedFollowing {}, &cal);
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_vec_from() {
        let options: Vec<(Vec<NaiveDateTime>, Vec<RollDay>)> = vec![
            (
                vec![ndt(2000, 2, 29)],
                vec![RollDay::Day(29), RollDay::Day(30), RollDay::Day(31)],
            ),
            (
                vec![ndt(2023, 2, 28)],
                vec![
                    RollDay::Day(28),
                    RollDay::Day(29),
                    RollDay::Day(30),
                    RollDay::Day(31),
                ],
            ),
            (vec![ndt(2025, 11, 28)], vec![RollDay::Day(28)]),
            (
                vec![ndt(2025, 3, 19)],
                vec![RollDay::Day(19), RollDay::IMM()],
            ),
            (vec![ndt(2025, 9, 15)], vec![RollDay::Day(15)]),
        ];

        for (udates, expected) in options {
            assert_eq!(RollDay::vec_from(&udates), expected);
        }
    }

    #[test]
    fn test_vec_from_multiple() {
        let options: Vec<(Vec<NaiveDateTime>, Vec<RollDay>)> = vec![
            (
                vec![ndt(2000, 2, 29)],
                vec![RollDay::Day(29), RollDay::Day(30), RollDay::Day(31)],
            ),
            (
                vec![ndt(2025, 11, 28), ndt(2025, 11, 29), ndt(2025, 11, 30)],
                vec![
                    RollDay::Day(28),
                    RollDay::Day(29),
                    RollDay::Day(30),
                    RollDay::Day(31),
                ],
            ),
            (
                vec![ndt(2025, 3, 19)],
                vec![RollDay::Day(19), RollDay::IMM()],
            ),
            (
                vec![ndt(2025, 9, 15), ndt(2025, 9, 14), ndt(2025, 9, 13)],
                vec![RollDay::Day(15), RollDay::Day(14), RollDay::Day(13)],
            ),
        ];

        for (udates, expected) in options {
            assert_eq!(RollDay::vec_from(&udates), expected);
        }
    }
}