1use crate::dual::dual::{Dual, Dual2, Gradient1, Gradient2, Vars};
4use crate::dual::dual_ops::math_funcs::MathFuncs;
5use crate::dual::enums::{ADOrder, Number};
6use crate::json::json_py::DeserializedObj;
7use crate::json::JSON;
8use num_traits::{Pow, Signed};
9use numpy::{Element, PyArray1, PyArray2, PyArrayDescr, ToPyArray};
10use pyo3::exceptions::{PyTypeError, PyValueError};
11use pyo3::prelude::*;
12use std::sync::Arc;
13
14unsafe impl Element for Dual {
15 const IS_COPY: bool = false;
16 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
17 PyArrayDescr::object(py)
18 }
19
20 fn clone_ref(&self, _py: Python<'_>) -> Self {
21 self.clone()
22 }
23}
24unsafe impl Element for Dual2 {
25 const IS_COPY: bool = false;
26 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
27 PyArrayDescr::object(py)
28 }
29
30 fn clone_ref(&self, _py: Python<'_>) -> Self {
31 self.clone()
32 }
33}
34
35#[pymethods]
44impl ADOrder {
45 #[new]
47 fn new_py(ad: u8) -> PyResult<ADOrder> {
48 match ad {
49 0_u8 => Ok(ADOrder::Zero),
50 1_u8 => Ok(ADOrder::One),
51 2_u8 => Ok(ADOrder::Two),
52 _ => Err(PyValueError::new_err("unreachable code on ADOrder pickle.")),
53 }
54 }
55 fn __getnewargs__<'py>(&self) -> PyResult<(u8,)> {
56 match self {
57 ADOrder::Zero => Ok((0_u8,)),
58 ADOrder::One => Ok((1_u8,)),
59 ADOrder::Two => Ok((2_u8,)),
60 }
61 }
62}
63
64#[pymethods]
65impl Dual {
66 #[new]
67 fn new_py(real: f64, vars: Vec<String>, dual: Vec<f64>) -> PyResult<Self> {
68 Dual::try_new(real, vars, dual)
69 }
70
71 #[staticmethod]
110 fn vars_from(other: &Dual, real: f64, vars: Vec<String>, dual: Vec<f64>) -> PyResult<Self> {
111 Dual::try_new_from(other, real, vars, dual)
112 }
113
114 #[getter]
116 #[pyo3(name = "real")]
117 fn real_py(&self) -> PyResult<f64> {
118 Ok(self.real())
119 }
120
121 #[getter]
123 #[pyo3(name = "vars")]
124 fn vars_py(&self) -> PyResult<Vec<&String>> {
125 Ok(Vec::from_iter(self.vars().iter()))
126 }
127
128 #[getter]
130 #[pyo3(name = "dual")]
131 fn dual_py<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
132 Ok(self.dual().to_pyarray(py))
133 }
134
135 #[getter]
137 #[pyo3(name = "dual2")]
138 fn dual2_py<'py>(&'py self, _py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
139 Err(PyValueError::new_err(
140 "`Dual` variable cannot possess `dual2` attribute.",
141 ))
142 }
143
144 #[pyo3(name = "grad1")]
155 fn grad1<'py>(
156 &'py self,
157 py: Python<'py>,
158 vars: Vec<String>,
159 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
160 Ok(self.gradient1(vars).to_pyarray(py))
161 }
162
163 #[pyo3(name = "grad2")]
165 fn grad2<'py>(
166 &'py self,
167 _py: Python<'py>,
168 _vars: Vec<String>,
169 ) -> PyResult<Bound<'py, PyArray2<f64>>> {
170 Err(PyValueError::new_err(
171 "Cannot evaluate second order derivative on a Dual.",
172 ))
173 }
174
175 #[pyo3(name = "ptr_eq")]
186 fn ptr_eq_py(&self, other: &Dual) -> PyResult<bool> {
187 Ok(Arc::ptr_eq(self.vars(), other.vars()))
188 }
189
190 fn __repr__(&self) -> PyResult<String> {
191 let mut _vars = Vec::from_iter(self.vars().iter().take(3).map(String::as_str)).join(", ");
192 let mut _dual =
193 Vec::from_iter(self.dual().iter().take(3).map(|x| format!("{:.1}", x))).join(", ");
194 if self.vars().len() > 3 {
195 _vars.push_str(", ...");
196 _dual.push_str(", ...");
197 }
198 let fs = format!("<Dual: {:.6}, ({}), [{}]>", self.real(), _vars, _dual);
199 Ok(fs)
200 }
201
202 fn __eq__(&self, other: Number) -> PyResult<bool> {
203 match other {
204 Number::Dual(d) => Ok(d.eq(self)),
205 Number::F64(f) => Ok(Dual::new(f, Vec::new()).eq(self)),
206 Number::Dual2(_) => Err(PyTypeError::new_err(
207 "Cannot compare Dual with incompatible type (Dual2).",
208 )),
209 }
210 }
211
212 fn __lt__(&self, other: Number) -> PyResult<bool> {
213 match other {
214 Number::Dual(d) => Ok(self < &d),
215 Number::F64(f) => Ok(self < &f),
216 Number::Dual2(_) => Err(PyTypeError::new_err(
217 "Cannot compare Dual with incompatible type (Dual2).",
218 )),
219 }
220 }
221
222 fn __le__(&self, other: Number) -> PyResult<bool> {
223 match other {
224 Number::Dual(d) => Ok(self <= &d),
225 Number::F64(f) => Ok(self <= &f),
226 Number::Dual2(_) => Err(PyTypeError::new_err(
227 "Cannot compare Dual with incompatible type (Dual2).",
228 )),
229 }
230 }
231
232 fn __gt__(&self, other: Number) -> PyResult<bool> {
233 match other {
234 Number::Dual(d) => Ok(self > &d),
235 Number::F64(f) => Ok(self > &f),
236 Number::Dual2(_) => Err(PyTypeError::new_err(
237 "Cannot compare Dual with incompatible type (Dual2).",
238 )),
239 }
240 }
241
242 fn __ge__(&self, other: Number) -> PyResult<bool> {
243 match other {
244 Number::Dual(d) => Ok(self >= &d),
245 Number::F64(f) => Ok(self >= &f),
246 Number::Dual2(_) => Err(PyTypeError::new_err(
247 "Cannot compare Dual with incompatible type (Dual2).",
248 )),
249 }
250 }
251
252 fn __neg__(&self) -> Self {
253 -self
254 }
255
256 fn __add__(&self, other: Number) -> PyResult<Self> {
257 match other {
258 Number::Dual(d) => Ok(self + d),
259 Number::F64(f) => Ok(self + f),
260 Number::Dual2(_) => Err(PyTypeError::new_err(
261 "Dual operation with incompatible type (Dual2).",
262 )),
263 }
264 }
265
266 fn __radd__(&self, other: Number) -> PyResult<Self> {
267 match other {
268 Number::Dual(d) => Ok(self + d),
269 Number::F64(f) => Ok(self + f),
270 Number::Dual2(_) => Err(PyTypeError::new_err(
271 "Dual operation with incompatible type (Dual2).",
272 )),
273 }
274 }
275
276 fn __sub__(&self, other: Number) -> PyResult<Self> {
277 match other {
278 Number::Dual(d) => Ok(self - d),
279 Number::F64(f) => Ok(self - f),
280 Number::Dual2(_) => Err(PyTypeError::new_err(
281 "Dual operation with incompatible type (Dual2).",
282 )),
283 }
284 }
285
286 fn __rsub__(&self, other: Number) -> PyResult<Self> {
287 match other {
288 Number::Dual(d) => Ok(d - self),
289 Number::F64(f) => Ok(f - self),
290 Number::Dual2(_) => Err(PyTypeError::new_err(
291 "Dual operation with incompatible type (Dual2).",
292 )),
293 }
294 }
295
296 fn __mul__(&self, other: Number) -> PyResult<Self> {
297 match other {
298 Number::Dual(d) => Ok(self * d),
299 Number::F64(f) => Ok(self * f),
300 Number::Dual2(_) => Err(PyTypeError::new_err(
301 "Dual operation with incompatible type (Dual2).",
302 )),
303 }
304 }
305
306 fn __rmul__(&self, other: Number) -> PyResult<Self> {
307 match other {
308 Number::Dual(d) => Ok(d * self),
309 Number::F64(f) => Ok(f * self),
310 Number::Dual2(_) => Err(PyTypeError::new_err(
311 "Dual operation with incompatible type (Dual2).",
312 )),
313 }
314 }
315
316 fn __truediv__(&self, other: Number) -> PyResult<Self> {
317 match other {
318 Number::Dual(d) => Ok(self / d),
319 Number::F64(f) => Ok(self / f),
320 Number::Dual2(_) => Err(PyTypeError::new_err(
321 "Dual operation with incompatible type (Dual2).",
322 )),
323 }
324 }
325
326 fn __rtruediv__(&self, other: Number) -> PyResult<Self> {
327 match other {
328 Number::Dual(d) => Ok(d / self),
329 Number::F64(f) => Ok(f / self),
330 Number::Dual2(_) => Err(PyTypeError::new_err(
331 "Dual operation with incompatible type (Dual2).",
332 )),
333 }
334 }
335
336 fn __pow__(&self, power: Number, modulo: Option<i32>) -> PyResult<Self> {
337 if modulo.unwrap_or(0) != 0 {
338 panic!("Power function with mod not available for Dual.")
339 }
340 match power {
341 Number::F64(f) => Ok(self.clone().pow(f)),
342 Number::Dual(d_) => Ok(self.pow(d_)),
343 Number::Dual2(_) => Err(PyTypeError::new_err(
344 "Power operation does not permit Dual/Dual2 type crossing.",
345 )),
346 }
347 }
348
349 fn __rpow__(&self, other: Number, modulo: Option<i32>) -> PyResult<Self> {
350 if modulo.unwrap_or(0) != 0 {
351 panic!("Power function with mod not available for Dual.")
352 }
353 match other {
354 Number::F64(f) => Ok(f.pow(self)),
355 Number::Dual(d_) => Ok(d_.pow(self)),
356 Number::Dual2(_) => Err(PyTypeError::new_err(
357 "Power operation does not permit Dual/Dual2 type crossing.",
358 )),
359 }
360 }
361
362 fn __exp__(&self) -> Self {
363 self.exp()
364 }
365
366 fn __abs__(&self) -> Self {
367 self.abs()
368 }
369
370 fn __log__(&self) -> Self {
371 self.log()
372 }
373
374 fn __norm_cdf__(&self) -> Self {
375 self.norm_cdf()
376 }
377
378 fn __norm_inv_cdf__(&self) -> Self {
379 self.inv_norm_cdf()
380 }
381
382 fn __float__(&self) -> f64 {
383 self.real()
384 }
385
386 #[pyo3(name = "to_json")]
393 fn to_json_py(&self) -> PyResult<String> {
394 match DeserializedObj::Dual(self.clone()).to_json() {
395 Ok(v) => Ok(v),
396 Err(_) => Err(PyValueError::new_err("Failed to serialize `Dual` to JSON.")),
397 }
398 }
399
400 pub fn __getnewargs__(&self) -> PyResult<(f64, Vec<String>, Vec<f64>)> {
402 Ok((
403 self.real,
404 self.vars().iter().cloned().collect(),
405 self.dual.to_vec(),
406 ))
407 }
408
409 #[pyo3(name = "to_dual2")]
411 fn to_dual2_py(&self) -> Dual2 {
412 self.clone().into()
413 }
414}
415
416#[pymethods]
417impl Dual2 {
418 #[new]
420 pub fn new_py(real: f64, vars: Vec<String>, dual: Vec<f64>, dual2: Vec<f64>) -> PyResult<Self> {
421 Dual2::try_new(real, vars, dual, dual2)
422 }
423
424 #[staticmethod]
458 pub fn vars_from(
459 other: &Dual2,
460 real: f64,
461 vars: Vec<String>,
462 dual: Vec<f64>,
463 dual2: Vec<f64>,
464 ) -> PyResult<Self> {
465 Dual2::try_new_from(other, real, vars, dual, dual2)
466 }
467
468 #[getter]
470 #[pyo3(name = "real")]
471 fn real_py(&self) -> PyResult<f64> {
472 Ok(self.real)
473 }
474
475 #[getter]
477 #[pyo3(name = "vars")]
478 fn vars_py(&self) -> PyResult<Vec<&String>> {
479 Ok(Vec::from_iter(self.vars.iter()))
480 }
481
482 #[getter]
484 #[pyo3(name = "dual")]
485 fn dual_py<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
486 Ok(self.dual.to_pyarray(py))
487 }
488
489 #[getter]
491 #[pyo3(name = "dual2")]
492 fn dual2_py<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyArray2<f64>>> {
493 Ok(self.dual2.to_pyarray(py))
494 }
495
496 #[pyo3(name = "grad1")]
507 fn grad1_py<'py>(
508 &'py self,
509 py: Python<'py>,
510 vars: Vec<String>,
511 ) -> PyResult<Bound<'py, PyArray1<f64>>> {
512 Ok(self.gradient1(vars).to_pyarray(py))
513 }
514
515 #[pyo3(name = "grad2")]
526 fn grad2_py<'py>(
527 &'py self,
528 py: Python<'py>,
529 vars: Vec<String>,
530 ) -> PyResult<Bound<'py, PyArray2<f64>>> {
531 Ok(self.gradient2(vars).to_pyarray(py))
532 }
533
534 #[pyo3(name = "grad1_manifold")]
548 fn grad1_manifold_py<'py>(
549 &'py self,
550 _py: Python<'py>,
551 vars: Vec<String>,
552 ) -> PyResult<Vec<Dual2>> {
553 let out = self.gradient1_manifold(vars);
554 Ok(out.into_raw_vec_and_offset().0)
555 }
556
557 #[pyo3(name = "ptr_eq")]
560 fn ptr_eq_py(&self, other: &Dual2) -> PyResult<bool> {
561 Ok(self.ptr_eq(other))
562 }
563
564 fn __repr__(&self) -> PyResult<String> {
565 let mut _vars = Vec::from_iter(self.vars.iter().take(3).map(String::as_str)).join(", ");
566 let mut _dual =
567 Vec::from_iter(self.dual.iter().take(3).map(|x| format!("{:.1}", x))).join(", ");
568 if self.vars.len() > 3 {
569 _vars.push_str(", ...");
570 _dual.push_str(", ...");
571 }
572 let fs = format!(
573 "<Dual2: {:.6}, ({}), [{}], [[...]]>",
574 self.real, _vars, _dual
575 );
576 Ok(fs)
577 }
578
579 fn __eq__(&self, other: Number) -> PyResult<bool> {
580 match other {
581 Number::Dual2(d) => Ok(d.eq(self)),
582 Number::F64(f) => Ok(Dual2::new(f, Vec::new()).eq(self)),
583 Number::Dual(_d) => Err(PyTypeError::new_err(
584 "Cannot compare Dual2 with incompatible type (Dual).",
585 )),
586 }
587 }
588
589 fn __lt__(&self, other: Number) -> PyResult<bool> {
590 match other {
591 Number::Dual2(d) => Ok(self < &d),
592 Number::F64(f) => Ok(self < &f),
593 Number::Dual(_d) => Err(PyTypeError::new_err(
594 "Cannot compare Dual2 with incompatible type (Dual).",
595 )),
596 }
597 }
598
599 fn __le__(&self, other: Number) -> PyResult<bool> {
600 match other {
601 Number::Dual2(d) => Ok(self <= &d),
602 Number::F64(f) => Ok(self <= &f),
603 Number::Dual(_d) => Err(PyTypeError::new_err(
604 "Cannot compare Dual2 with incompatible type (Dual).",
605 )),
606 }
607 }
608
609 fn __gt__(&self, other: Number) -> PyResult<bool> {
610 match other {
611 Number::Dual2(d) => Ok(self > &d),
612 Number::F64(f) => Ok(self > &f),
613 Number::Dual(_d) => Err(PyTypeError::new_err(
614 "Cannot compare Dual2 with incompatible type (Dual).",
615 )),
616 }
617 }
618
619 fn __ge__(&self, other: Number) -> PyResult<bool> {
620 match other {
621 Number::Dual2(d) => Ok(self >= &d),
622 Number::F64(f) => Ok(self >= &f),
623 Number::Dual(_d) => Err(PyTypeError::new_err(
624 "Cannot compare Dual2 with incompatible type (Dual).",
625 )),
626 }
627 }
628
629 fn __neg__(&self) -> Self {
630 -self
631 }
632
633 fn __add__(&self, other: Number) -> PyResult<Self> {
634 match other {
635 Number::Dual2(d) => Ok(self + d),
636 Number::F64(f) => Ok(self + f),
637 Number::Dual(_d) => Err(PyTypeError::new_err(
638 "Dual2 operation with incompatible type (Dual).",
639 )),
640 }
641 }
642
643 fn __radd__(&self, other: Number) -> PyResult<Self> {
644 match other {
645 Number::Dual2(d) => Ok(self + d),
646 Number::F64(f) => Ok(self + f),
647 Number::Dual(_d) => Err(PyTypeError::new_err(
648 "Dual2 operation with incompatible type (Dual).",
649 )),
650 }
651 }
652
653 fn __sub__(&self, other: Number) -> PyResult<Self> {
654 match other {
655 Number::Dual2(d) => Ok(self - d),
656 Number::F64(f) => Ok(self - f),
657 Number::Dual(_d) => Err(PyTypeError::new_err(
658 "Dual2 operation with incompatible type (Dual).",
659 )),
660 }
661 }
662
663 fn __rsub__(&self, other: Number) -> PyResult<Self> {
664 match other {
665 Number::Dual2(d) => Ok(d - self),
666 Number::F64(f) => Ok(f - self),
667 Number::Dual(_d) => Err(PyTypeError::new_err(
668 "Dual2 operation with incompatible type (Dual).",
669 )),
670 }
671 }
672
673 fn __mul__(&self, other: Number) -> PyResult<Self> {
674 match other {
675 Number::Dual2(d) => Ok(self * d),
676 Number::F64(f) => Ok(self * f),
677 Number::Dual(_d) => Err(PyTypeError::new_err(
678 "Dual2 operation with incompatible type (Dual).",
679 )),
680 }
681 }
682
683 fn __rmul__(&self, other: Number) -> PyResult<Self> {
684 match other {
685 Number::Dual2(d) => Ok(d * self),
686 Number::F64(f) => Ok(f * self),
687 Number::Dual(_d) => Err(PyTypeError::new_err(
688 "Dual2 operation with incompatible type (Dual).",
689 )),
690 }
691 }
692
693 fn __truediv__(&self, other: Number) -> PyResult<Self> {
694 match other {
695 Number::Dual2(d) => Ok(self / d),
696 Number::F64(f) => Ok(self / f),
697 Number::Dual(_d) => Err(PyTypeError::new_err(
698 "Dual2 operation with incompatible type (Dual).",
699 )),
700 }
701 }
702
703 fn __rtruediv__(&self, other: Number) -> PyResult<Self> {
704 match other {
705 Number::Dual2(d) => Ok(d / self),
706 Number::F64(f) => Ok(f / self),
707 Number::Dual(_d) => Err(PyTypeError::new_err(
708 "Dual2 operation with incompatible type (Dual).",
709 )),
710 }
711 }
712
713 fn __pow__(&self, power: Number, modulo: Option<i32>) -> PyResult<Self> {
714 if modulo.unwrap_or(0) != 0 {
715 panic!("Power function with mod not available for Dual.")
716 }
717 match power {
718 Number::F64(f) => Ok(self.clone().pow(f)),
719 Number::Dual(_d) => Err(PyTypeError::new_err(
720 "Power operation does not permit Dual/Dual2 type crossing.",
721 )),
722 Number::Dual2(d) => Ok(self.pow(d)),
723 }
724 }
725
726 fn __rpow__(&self, other: Number, modulo: Option<i32>) -> PyResult<Self> {
727 if modulo.unwrap_or(0) != 0 {
728 panic!("Power function with mod not available for Dual2.")
729 }
730 match other {
731 Number::F64(f) => Ok(f.pow(self)),
732 Number::Dual(_d) => Err(PyTypeError::new_err(
733 "Power operation does not permit Dual/Dual2 type crossing.",
734 )),
735 Number::Dual2(d_) => Ok(d_.pow(self)),
736 }
737 }
738
739 fn __exp__(&self) -> Self {
740 self.exp()
741 }
742
743 fn __abs__(&self) -> Self {
744 self.abs()
745 }
746
747 fn __log__(&self) -> Self {
748 self.log()
749 }
750
751 fn __norm_cdf__(&self) -> Self {
752 self.norm_cdf()
753 }
754
755 fn __norm_inv_cdf__(&self) -> Self {
756 self.inv_norm_cdf()
757 }
758
759 fn __float__(&self) -> f64 {
760 self.real
761 }
762
763 #[pyo3(name = "to_json")]
770 fn to_json_py(&self) -> PyResult<String> {
771 match DeserializedObj::Dual2(self.clone()).to_json() {
772 Ok(v) => Ok(v),
773 Err(_) => Err(PyValueError::new_err(
774 "Failed to serialize `Dual2` to JSON.",
775 )),
776 }
777 }
778
779 fn __getnewargs__(&self) -> PyResult<(f64, Vec<String>, Vec<f64>, Vec<f64>)> {
781 Ok((
782 self.real,
783 self.vars().iter().cloned().collect(),
784 self.dual.to_vec(),
785 self.dual2.clone().into_raw_vec_and_offset().0,
786 ))
787 }
788
789 #[pyo3(name = "to_dual")]
791 fn to_dual_py(&self) -> Dual {
792 self.clone().into()
793 }
794}