1use crate::json::json_py::DeserializedObj;
4use crate::json::JSON;
5use crate::scheduling::py::adjuster::get_roll_adjuster_from_str;
6use crate::scheduling::{
7 Adjuster, Adjustment, Cal, Calendar, CalendarAdjustment, Convention, DateRoll, NamedCal,
8 PyAdjuster, RollDay, UnionCal,
9};
10use bincode::config::legacy;
11use bincode::serde::{decode_from_slice, encode_to_vec};
12use chrono::NaiveDateTime;
13use indexmap::set::IndexSet;
14use pyo3::exceptions::PyValueError;
15use pyo3::prelude::*;
16use pyo3::types::{PyBytes, PyType};
17use std::collections::HashSet;
18
19#[pymethods]
20impl Convention {
21 #[new]
23 fn new_py(ad: u8) -> PyResult<Convention> {
24 match ad {
25 0_u8 => Ok(Convention::One),
26 1_u8 => Ok(Convention::OnePlus),
27 2_u8 => Ok(Convention::Act365F),
28 3_u8 => Ok(Convention::Act365FPlus),
29 4_u8 => Ok(Convention::Act360),
30 5_u8 => Ok(Convention::ThirtyE360),
31 6_u8 => Ok(Convention::Thirty360),
32 7_u8 => Ok(Convention::Thirty360ISDA),
33 8_u8 => Ok(Convention::ActActISDA),
34 9_u8 => Ok(Convention::ActActICMA),
35 10_u8 => Ok(Convention::Bus252),
36 _ => Err(PyValueError::new_err(
37 "unreachable code on Convention pickle.",
38 )),
39 }
40 }
41 pub fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
42 *self = decode_from_slice(state.as_bytes(), legacy()).unwrap().0;
43 Ok(())
44 }
45 pub fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
46 Ok(PyBytes::new(py, &encode_to_vec(&self, legacy()).unwrap()))
47 }
48 pub fn __getnewargs__<'py>(&self) -> PyResult<(u8,)> {
49 match self {
50 Convention::One => Ok((0_u8,)),
51 Convention::OnePlus => Ok((1_u8,)),
52 Convention::Act365F => Ok((2_u8,)),
53 Convention::Act365FPlus => Ok((3_u8,)),
54 Convention::Act360 => Ok((4_u8,)),
55 Convention::ThirtyE360 => Ok((5_u8,)),
56 Convention::Thirty360 => Ok((6_u8,)),
57 Convention::Thirty360ISDA => Ok((7_u8,)),
58 Convention::ActActISDA => Ok((8_u8,)),
59 Convention::ActActICMA => Ok((9_u8,)),
60 Convention::Bus252 => Ok((10_u8,)),
61 }
62 }
63}
64
65#[pymethods]
66impl Cal {
67 #[new]
76 fn new_py(holidays: Vec<NaiveDateTime>, week_mask: Vec<u8>) -> PyResult<Self> {
77 Ok(Cal::new(holidays, week_mask))
78 }
79
80 #[classmethod]
90 #[pyo3(name = "from_name")]
91 fn from_name_py(_cls: &Bound<'_, PyType>, name: String) -> PyResult<Self> {
92 Cal::try_from_name(&name)
93 }
94
95 #[getter]
97 fn holidays(&self) -> PyResult<Vec<NaiveDateTime>> {
98 Ok(self.holidays.clone().into_iter().collect())
99 }
100
101 #[getter]
103 fn week_mask(&self) -> PyResult<HashSet<u8>> {
104 Ok(HashSet::from_iter(
105 self.week_mask
106 .clone()
107 .into_iter()
108 .map(|x| x.num_days_from_monday() as u8),
109 ))
110 }
111
112 #[pyo3(name = "is_bus_day")]
128 fn is_bus_day_py(&self, date: NaiveDateTime) -> bool {
129 self.is_bus_day(&date)
130 }
131
132 #[pyo3(name = "is_non_bus_day")]
143 fn is_non_bus_day_py(&self, date: NaiveDateTime) -> bool {
144 self.is_non_bus_day(&date)
145 }
146
147 #[pyo3(name = "is_settlement")]
163 fn is_settlement_py(&self, date: NaiveDateTime) -> bool {
164 self.is_settlement(&date)
165 }
166
167 #[pyo3(name = "add_cal_days")]
182 fn add_cal_days_py(
183 &self,
184 date: NaiveDateTime,
185 days: i32,
186 adjuster: PyAdjuster,
187 ) -> PyResult<NaiveDateTime> {
188 Ok(self.add_cal_days(&date, days, &adjuster.into()))
189 }
190
191 #[pyo3(name = "add_bus_days")]
218 fn add_bus_days_py(
219 &self,
220 date: NaiveDateTime,
221 days: i32,
222 settlement: bool,
223 ) -> PyResult<NaiveDateTime> {
224 self.add_bus_days(&date, days, settlement)
225 }
226
227 #[pyo3(name = "add_months")]
244 fn add_months_py(
245 &self,
246 date: NaiveDateTime,
247 months: i32,
248 adjuster: PyAdjuster,
249 roll: Option<RollDay>,
250 ) -> NaiveDateTime {
251 let roll_ = match roll {
252 Some(val) => val,
253 None => RollDay::vec_from(&vec![date])[0],
254 };
255 let adjuster: Adjuster = adjuster.into();
256 adjuster.adjust(&roll_.uadd(&date, months), self)
257 }
258
259 #[pyo3(name = "roll")]
274 fn roll_py(
275 &self,
276 date: NaiveDateTime,
277 modifier: &str,
278 settlement: bool,
279 ) -> PyResult<NaiveDateTime> {
280 let adjuster = get_roll_adjuster_from_str((&modifier.to_lowercase(), settlement))?;
281 Ok(self.adjust(&date, &adjuster))
282 }
283
284 #[pyo3(name = "adjust")]
297 fn adjust_py(&self, date: NaiveDateTime, adjuster: PyAdjuster) -> PyResult<NaiveDateTime> {
298 Ok(self.adjust(&date, &adjuster.into()))
299 }
300
301 #[pyo3(name = "adjusts")]
314 fn adjusts_py(
315 &self,
316 dates: Vec<NaiveDateTime>,
317 adjuster: PyAdjuster,
318 ) -> PyResult<Vec<NaiveDateTime>> {
319 Ok(self.adjusts(&dates, &adjuster.into()))
320 }
321
322 #[pyo3(name = "lag_bus_days")]
353 fn lag_bus_days_py(&self, date: NaiveDateTime, days: i32, settlement: bool) -> NaiveDateTime {
354 self.lag_bus_days(&date, days, settlement)
355 }
356
357 #[pyo3(name = "bus_date_range")]
370 fn bus_date_range_py(
371 &self,
372 start: NaiveDateTime,
373 end: NaiveDateTime,
374 ) -> PyResult<Vec<NaiveDateTime>> {
375 self.bus_date_range(&start, &end)
376 }
377
378 #[pyo3(name = "cal_date_range")]
391 fn cal_date_range_py(
392 &self,
393 start: NaiveDateTime,
394 end: NaiveDateTime,
395 ) -> PyResult<Vec<NaiveDateTime>> {
396 self.cal_date_range(&start, &end)
397 }
398
399 fn __getnewargs__(&self) -> PyResult<(Vec<NaiveDateTime>, Vec<u8>)> {
401 Ok((
402 self.clone().holidays.into_iter().collect(),
403 self.clone()
404 .week_mask
405 .into_iter()
406 .map(|x| x.num_days_from_monday() as u8)
407 .collect(),
408 ))
409 }
410
411 #[pyo3(name = "to_json")]
418 fn to_json_py(&self) -> PyResult<String> {
419 match DeserializedObj::Cal(self.clone()).to_json() {
420 Ok(v) => Ok(v),
421 Err(_) => Err(PyValueError::new_err("Failed to serialize `Cal` to JSON.")),
422 }
423 }
424
425 fn __eq__(&self, other: Calendar) -> bool {
427 match other {
428 Calendar::UnionCal(c) => *self == c,
429 Calendar::Cal(c) => *self == c,
430 Calendar::NamedCal(c) => *self == c,
431 }
432 }
433}
434
435#[pymethods]
436impl UnionCal {
437 #[new]
438 #[pyo3(signature = (calendars, settlement_calendars=None))]
439 fn new_py(calendars: Vec<Cal>, settlement_calendars: Option<Vec<Cal>>) -> PyResult<Self> {
440 Ok(UnionCal::new(calendars, settlement_calendars))
441 }
442
443 #[getter]
445 fn holidays(&self) -> PyResult<Vec<NaiveDateTime>> {
446 let mut set = self.calendars.iter().fold(IndexSet::new(), |acc, x| {
447 IndexSet::from_iter(acc.union(&x.holidays).cloned())
448 });
449 set.sort();
450 Ok(Vec::from_iter(set))
451 }
452
453 #[getter]
455 fn week_mask(&self) -> PyResult<HashSet<u8>> {
456 let mut s: HashSet<u8> = HashSet::new();
457 for cal in &self.calendars {
458 let ns = cal.week_mask()?;
459 s.extend(&ns);
460 }
461 Ok(s)
462 }
463
464 #[getter]
466 fn calendars(&self) -> Vec<Cal> {
467 self.calendars.clone()
468 }
469
470 #[getter]
472 fn settlement_calendars(&self) -> Option<Vec<Cal>> {
473 self.settlement_calendars.clone()
474 }
475
476 #[pyo3(name = "is_bus_day")]
480 fn is_bus_day_py(&self, date: NaiveDateTime) -> bool {
481 self.is_bus_day(&date)
482 }
483
484 #[pyo3(name = "is_non_bus_day")]
488 fn is_non_bus_day_py(&self, date: NaiveDateTime) -> bool {
489 self.is_non_bus_day(&date)
490 }
491
492 #[pyo3(name = "is_settlement")]
498 fn is_settlement_py(&self, date: NaiveDateTime) -> bool {
499 self.is_settlement(&date)
500 }
501
502 #[pyo3(name = "add_cal_days")]
506 fn add_cal_days_py(
507 &self,
508 date: NaiveDateTime,
509 days: i32,
510 adjuster: PyAdjuster,
511 ) -> PyResult<NaiveDateTime> {
512 Ok(self.add_cal_days(&date, days, &adjuster.into()))
513 }
514
515 #[pyo3(name = "add_bus_days")]
519 fn add_bus_days_py(
520 &self,
521 date: NaiveDateTime,
522 days: i32,
523 settlement: bool,
524 ) -> PyResult<NaiveDateTime> {
525 self.add_bus_days(&date, days, settlement)
526 }
527
528 #[pyo3(name = "add_months")]
532 fn add_months_py(
533 &self,
534 date: NaiveDateTime,
535 months: i32,
536 adjuster: PyAdjuster,
537 roll: Option<RollDay>,
538 ) -> NaiveDateTime {
539 let roll_ = match roll {
540 Some(val) => val,
541 None => RollDay::vec_from(&vec![date])[0],
542 };
543 let adjuster: Adjuster = adjuster.into();
544 adjuster.adjust(&roll_.uadd(&date, months), self)
545 }
546
547 #[pyo3(name = "adjust")]
551 fn adjust_py(&self, date: NaiveDateTime, adjuster: PyAdjuster) -> PyResult<NaiveDateTime> {
552 Ok(self.adjust(&date, &adjuster.into()))
553 }
554
555 #[pyo3(name = "adjusts")]
559 fn adjusts_py(
560 &self,
561 dates: Vec<NaiveDateTime>,
562 adjuster: PyAdjuster,
563 ) -> PyResult<Vec<NaiveDateTime>> {
564 Ok(self.adjusts(&dates, &adjuster.into()))
565 }
566
567 #[pyo3(name = "roll")]
571 fn roll_py(
572 &self,
573 date: NaiveDateTime,
574 modifier: &str,
575 settlement: bool,
576 ) -> PyResult<NaiveDateTime> {
577 let adjuster = get_roll_adjuster_from_str((&modifier.to_lowercase(), settlement))?;
578 Ok(self.adjust(&date, &adjuster))
579 }
580
581 #[pyo3(name = "lag_bus_days")]
585 fn lag_bus_days_py(&self, date: NaiveDateTime, days: i32, settlement: bool) -> NaiveDateTime {
586 self.lag_bus_days(&date, days, settlement)
587 }
588
589 #[pyo3(name = "bus_date_range")]
593 fn bus_date_range_py(
594 &self,
595 start: NaiveDateTime,
596 end: NaiveDateTime,
597 ) -> PyResult<Vec<NaiveDateTime>> {
598 self.bus_date_range(&start, &end)
599 }
600
601 #[pyo3(name = "cal_date_range")]
605 fn cal_date_range_py(
606 &self,
607 start: NaiveDateTime,
608 end: NaiveDateTime,
609 ) -> PyResult<Vec<NaiveDateTime>> {
610 self.cal_date_range(&start, &end)
611 }
612
613 fn __getnewargs__(&self) -> PyResult<(Vec<Cal>, Option<Vec<Cal>>)> {
615 Ok((self.calendars.clone(), self.settlement_calendars.clone()))
616 }
617
618 #[pyo3(name = "to_json")]
625 fn to_json_py(&self) -> PyResult<String> {
626 match DeserializedObj::UnionCal(self.clone()).to_json() {
627 Ok(v) => Ok(v),
628 Err(_) => Err(PyValueError::new_err(
629 "Failed to serialize `UnionCal` to JSON.",
630 )),
631 }
632 }
633
634 fn __eq__(&self, other: Calendar) -> bool {
636 match other {
637 Calendar::UnionCal(c) => *self == c,
638 Calendar::Cal(c) => *self == c,
639 Calendar::NamedCal(c) => *self == c,
640 }
641 }
642}
643
644#[pymethods]
645impl NamedCal {
646 #[new]
647 fn new_py(name: String) -> PyResult<Self> {
648 NamedCal::try_new(&name)
649 }
650
651 #[getter]
653 fn holidays(&self) -> PyResult<Vec<NaiveDateTime>> {
654 self.union_cal.holidays()
655 }
656
657 #[getter]
659 fn week_mask(&self) -> PyResult<HashSet<u8>> {
660 self.union_cal.week_mask()
661 }
662
663 #[getter]
665 fn name(&self) -> String {
666 self.name.clone()
667 }
668
669 #[getter]
671 fn union_cal(&self) -> UnionCal {
672 self.union_cal.clone()
673 }
674
675 #[pyo3(name = "is_bus_day")]
679 fn is_bus_day_py(&self, date: NaiveDateTime) -> bool {
680 self.is_bus_day(&date)
681 }
682
683 #[pyo3(name = "is_non_bus_day")]
687 fn is_non_bus_day_py(&self, date: NaiveDateTime) -> bool {
688 self.is_non_bus_day(&date)
689 }
690
691 #[pyo3(name = "is_settlement")]
697 fn is_settlement_py(&self, date: NaiveDateTime) -> bool {
698 self.is_settlement(&date)
699 }
700
701 #[pyo3(name = "add_cal_days")]
705 fn add_cal_days_py(
706 &self,
707 date: NaiveDateTime,
708 days: i32,
709 adjuster: PyAdjuster,
710 ) -> PyResult<NaiveDateTime> {
711 Ok(self.add_cal_days(&date, days, &adjuster.into()))
712 }
713
714 #[pyo3(name = "add_bus_days")]
718 fn add_bus_days_py(
719 &self,
720 date: NaiveDateTime,
721 days: i32,
722 settlement: bool,
723 ) -> PyResult<NaiveDateTime> {
724 self.add_bus_days(&date, days, settlement)
725 }
726
727 #[pyo3(name = "add_months")]
731 fn add_months_py(
732 &self,
733 date: NaiveDateTime,
734 months: i32,
735 adjuster: PyAdjuster,
736 roll: Option<RollDay>,
737 ) -> NaiveDateTime {
738 let roll_ = match roll {
739 Some(val) => val,
740 None => RollDay::vec_from(&vec![date])[0],
741 };
742 let adjuster: Adjuster = adjuster.into();
743 adjuster.adjust(&roll_.uadd(&date, months), self)
744 }
745
746 #[pyo3(name = "adjust")]
750 fn adjust_py(&self, date: NaiveDateTime, adjuster: PyAdjuster) -> PyResult<NaiveDateTime> {
751 Ok(self.adjust(&date, &adjuster.into()))
752 }
753
754 #[pyo3(name = "adjusts")]
758 fn adjusts_py(
759 &self,
760 dates: Vec<NaiveDateTime>,
761 adjuster: PyAdjuster,
762 ) -> PyResult<Vec<NaiveDateTime>> {
763 Ok(self.adjusts(&dates, &adjuster.into()))
764 }
765
766 #[pyo3(name = "roll")]
770 fn roll_py(
771 &self,
772 date: NaiveDateTime,
773 modifier: &str,
774 settlement: bool,
775 ) -> PyResult<NaiveDateTime> {
776 let adjuster = get_roll_adjuster_from_str((&modifier.to_lowercase(), settlement))?;
777 Ok(self.adjust(&date, &adjuster))
778 }
779
780 #[pyo3(name = "lag_bus_days")]
784 fn lag_bus_days_py(&self, date: NaiveDateTime, days: i32, settlement: bool) -> NaiveDateTime {
785 self.lag_bus_days(&date, days, settlement)
786 }
787
788 #[pyo3(name = "bus_date_range")]
792 fn bus_date_range_py(
793 &self,
794 start: NaiveDateTime,
795 end: NaiveDateTime,
796 ) -> PyResult<Vec<NaiveDateTime>> {
797 self.bus_date_range(&start, &end)
798 }
799
800 #[pyo3(name = "cal_date_range")]
804 fn cal_date_range_py(
805 &self,
806 start: NaiveDateTime,
807 end: NaiveDateTime,
808 ) -> PyResult<Vec<NaiveDateTime>> {
809 self.cal_date_range(&start, &end)
810 }
811
812 fn __getnewargs__(&self) -> PyResult<(String,)> {
814 Ok((self.name.clone(),))
815 }
816
817 #[pyo3(name = "to_json")]
824 fn to_json_py(&self) -> PyResult<String> {
825 match DeserializedObj::NamedCal(self.clone()).to_json() {
826 Ok(v) => Ok(v),
827 Err(_) => Err(PyValueError::new_err(
828 "Failed to serialize `NamedCal` to JSON.",
829 )),
830 }
831 }
832
833 fn __eq__(&self, other: Calendar) -> bool {
835 match other {
836 Calendar::UnionCal(c) => *self == c,
837 Calendar::Cal(c) => *self == c,
838 Calendar::NamedCal(c) => *self == c,
839 }
840 }
841}
842
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scheduling::ndt;

    /// Every pickle code accepted by `Convention::new_py` must round-trip through
    /// `__getnewargs__`.
    #[test]
    fn test_convention_pickle_code_round_trip() {
        for code in 0_u8..=10 {
            let convention = Convention::new_py(code).unwrap();
            assert_eq!(convention.__getnewargs__().unwrap(), (code,));
        }
    }

    #[test]
    fn test_add_37_months() {
        let cal = Cal::try_from_name("all").unwrap();
        let cases = vec![
            (ndt(2000, 1, 1), ndt(2003, 2, 1)),
            (ndt(2000, 2, 1), ndt(2003, 3, 1)),
            (ndt(2000, 3, 1), ndt(2003, 4, 1)),
            (ndt(2000, 4, 1), ndt(2003, 5, 1)),
            (ndt(2000, 5, 1), ndt(2003, 6, 1)),
            (ndt(2000, 6, 1), ndt(2003, 7, 1)),
            (ndt(2000, 7, 1), ndt(2003, 8, 1)),
            (ndt(2000, 8, 1), ndt(2003, 9, 1)),
            (ndt(2000, 9, 1), ndt(2003, 10, 1)),
            (ndt(2000, 10, 1), ndt(2003, 11, 1)),
            (ndt(2000, 11, 1), ndt(2003, 12, 1)),
            (ndt(2000, 12, 1), ndt(2004, 1, 1)),
        ];
        for (start, expected) in cases {
            assert_eq!(
                cal.add_months_py(
                    start,
                    37,
                    Adjuster::FollowingSettle {}.into(),
                    Some(RollDay::Day(1)),
                ),
                expected,
                "start: {:?}",
                start
            );
        }
    }

    #[test]
    fn test_sub_37_months() {
        let cal = Cal::try_from_name("all").unwrap();
        let cases = vec![
            (ndt(2000, 1, 1), ndt(1996, 12, 1)),
            (ndt(2000, 2, 1), ndt(1997, 1, 1)),
            (ndt(2000, 3, 1), ndt(1997, 2, 1)),
            (ndt(2000, 4, 1), ndt(1997, 3, 1)),
            (ndt(2000, 5, 1), ndt(1997, 4, 1)),
            (ndt(2000, 6, 1), ndt(1997, 5, 1)),
            (ndt(2000, 7, 1), ndt(1997, 6, 1)),
            (ndt(2000, 8, 1), ndt(1997, 7, 1)),
            (ndt(2000, 9, 1), ndt(1997, 8, 1)),
            (ndt(2000, 10, 1), ndt(1997, 9, 1)),
            (ndt(2000, 11, 1), ndt(1997, 10, 1)),
            (ndt(2000, 12, 1), ndt(1997, 11, 1)),
        ];
        for (start, expected) in cases {
            assert_eq!(
                cal.add_months_py(
                    start,
                    -37,
                    Adjuster::FollowingSettle {}.into(),
                    Some(RollDay::Day(1)),
                ),
                expected,
                "start: {:?}",
                start
            );
        }
    }

    #[test]
    fn test_add_months_py_roll() {
        let cal = Cal::try_from_name("all").unwrap();
        let cases = vec![
            (RollDay::Day(7), ndt(1998, 3, 7), ndt(1996, 12, 7)),
            (RollDay::Day(21), ndt(1998, 3, 21), ndt(1996, 12, 21)),
            (RollDay::Day(31), ndt(1998, 3, 31), ndt(1996, 12, 31)),
            (RollDay::Day(1), ndt(1998, 3, 1), ndt(1996, 12, 1)),
            (RollDay::IMM(), ndt(1998, 3, 18), ndt(1996, 12, 18)),
        ];
        for (roll, start, expected) in cases {
            assert_eq!(
                cal.add_months_py(start, -15, Adjuster::FollowingSettle {}.into(), Some(roll)),
                expected,
                "start: {:?}",
                start
            );
        }
    }

    #[test]
    fn test_add_months_roll_invalid_days() {
        let cal = Cal::try_from_name("all").unwrap();
        let cases = vec![
            (RollDay::Day(21), ndt(1996, 12, 21)),
            (RollDay::Day(31), ndt(1996, 12, 31)),
            (RollDay::Day(1), ndt(1996, 12, 1)),
            (RollDay::IMM(), ndt(1996, 12, 18)),
        ];
        for (roll, expected) in cases {
            assert_eq!(
                expected,
                cal.add_months_py(
                    ndt(1998, 3, 7),
                    -15,
                    Adjuster::FollowingSettle {}.into(),
                    Some(roll),
                ),
            );
        }
    }

    #[test]
    fn test_add_months_modifier() {
        let cal = Cal::try_from_name("bus").unwrap();
        // Iterate every case: a hard-coded `0..4` silently skipped the last entry.
        let cases = vec![
            (Adjuster::Actual {}, ndt(2023, 9, 30)),
            (Adjuster::FollowingSettle {}, ndt(2023, 10, 2)),
            (Adjuster::ModifiedFollowingSettle {}, ndt(2023, 9, 29)),
            (Adjuster::PreviousSettle {}, ndt(2023, 9, 29)),
            (Adjuster::ModifiedPreviousSettle {}, ndt(2023, 9, 29)),
        ];
        for (adjuster, expected) in cases {
            assert_eq!(
                cal.add_months_py(
                    ndt(2023, 8, 31),
                    1,
                    adjuster.into(),
                    Some(RollDay::Day(31))
                ),
                expected
            );
        }
    }

    #[test]
    fn test_add_months_modifier_p() {
        let cal = Cal::try_from_name("bus").unwrap();
        // Iterate every case: a hard-coded `0..4` silently skipped the last entry.
        let cases = vec![
            (Adjuster::Actual {}, ndt(2023, 7, 1)),
            (Adjuster::FollowingSettle {}, ndt(2023, 7, 3)),
            (Adjuster::ModifiedFollowingSettle {}, ndt(2023, 7, 3)),
            (Adjuster::PreviousSettle {}, ndt(2023, 6, 30)),
            (Adjuster::ModifiedPreviousSettle {}, ndt(2023, 7, 3)),
        ];
        for (adjuster, expected) in cases {
            assert_eq!(
                cal.add_months_py(ndt(2023, 8, 1), -1, adjuster.into(), Some(RollDay::Day(1))),
                expected
            );
        }
    }
}