use nalgebra::SVector; /// The basic ODE object that will be passed around. The type (T) and the size (D) will be /// determined upon creation of the object #[derive(Clone, Copy)] pub struct ODE<'a, const D: usize, P> { pub f: &'a dyn Fn(f64, SVector, &P) -> SVector, pub y: SVector, pub t: f64, pub params: P, pub t0: f64, pub t_end: f64, pub h: f64, pub finished: bool, } impl<'a, const D: usize, P> ODE<'a, D, P> { pub fn new( f: &'a (dyn Fn(f64, SVector, &P) -> SVector), t0: f64, t_end: f64, y0: SVector, params: P, ) -> Self { Self { f, y: y0, t: t0, params, t0, t_end, h: 0.001, finished: false, } } } #[cfg(test)] mod tests { use super::*; use nalgebra::Vector3; #[test] fn test_ode_creation() { type Params = (); fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { -y } let y0 = Vector3::new(1.0, 0.0, 0.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); assert!((ode.f)(0.0, y0, &()) == Vector3::new(-1.0, 0.0, 0.0)); assert!(ode.y == Vector3::new(1.0, 0.0, 0.0)); assert!(ode.t == 0.0); assert!(!ode.finished); assert!(ode.t_end == 10.0); } #[test] fn test_ode_with_params() { type Params = (f64, bool); let params = (34.0, true); fn derivative(t: f64, y: Vector3, p: &Params) -> Vector3 { if p.1 { -y } else { y * t } } let y0 = Vector3::new(1.0, 0.0, 0.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); assert!((ode.f)(0.0, y0, ¶ms) == Vector3::new(-1.0, 0.0, 0.0)); assert!(ode.y == Vector3::new(1.0, 0.0, 0.0)); assert!(ode.t == 0.0); assert!(!ode.finished); assert!(ode.t_end == 10.0); } }