diff --git a/readme.md b/readme.md index 0df1ac8..f9c8abc 100644 --- a/readme.md +++ b/readme.md @@ -39,20 +39,30 @@ use differential_equations::controller::PIController; use differential_equations::callback::stop; use differential_equations::problem::*; -// Define the system -fn derivative(_t: f64, y: Vector3) -> Vector3 { y } +// Define the system (parameters, derivative, and initial state) +type Params = (f64, bool); +let params = (34.0, true); + +fn derivative(t: f64, y: Vector3, p: &Params) -> Vector3 { + if p.1 { -y } else { y * t } +} + let y0 = Vector3::new(1.0, 1.0, 1.0); -let ode = ODE::new(&derivative, 0.0, 10.0, y0); +// Set up the problem (ODE, Integrator, Controller, and Callbacks) +let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let controller = PIController::default(); let value_too_high = Callback { - event: &|_: f64, y: SVector| { 10.0 - y[0] }, + event: &|_: f64, y: Vector3, _: &Params| { 10.0 - y[0] }, effect: &stop, }; +// Solve the problem let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); let solution = problem.solve(); + +// Can interpolate solutions to whatever you want let interpolated_answer = solution.interpolate(8.2); ``` diff --git a/src/callback.rs b/src/callback.rs index 58591db..b559bac 100644 --- a/src/callback.rs +++ b/src/callback.rs @@ -1,23 +1,22 @@ use nalgebra::SVector; + use super::ode::ODE; /// A function that takes in a time and a state and outputs a single float value /// /// The integration solver will check this function for zero crossings #[derive(Clone, Copy)] -pub struct Callback<'a, const D: usize> { +pub struct Callback<'a, const D: usize, P> { /// The function to check for zero crossings - pub event: &'a dyn Fn(f64, SVector) -> f64, + pub event: &'a dyn Fn(f64, SVector, &P) -> f64, /// The function to change the ODE - pub effect: &'a dyn Fn(ODE) -> ODE, + pub effect: &'a dyn Fn(&mut ODE) -> (), } /// A convenience function for stopping the integration -pub fn stop(ode: ODE) -> ODE { - let mut new_ode = ode.clone(); - new_ode.t_end = new_ode.t; - new_ode +pub fn stop<'a, const D: usize, P>(ode: &mut ODE) -> () { + ode.t_end = ode.t; } #[cfg(test)] @@ -26,8 +25,9 @@ mod tests { #[test] fn test_basic_callbacks() { + type Params = (); let _value_too_high = Callback { - event: &|_: f64, y: SVector| { 10.0 - y[0] }, + event: &|_: f64, y: SVector, _p: &Params| { 10.0 - y[0] }, effect: &stop, }; } diff --git a/src/integrator/dormand_prince.rs b/src/integrator/dormand_prince.rs index a1cc201..9940111 100644 --- a/src/integrator/dormand_prince.rs +++ b/src/integrator/dormand_prince.rs @@ -95,13 +95,13 @@ where const ADAPTIVE: bool = true; const DENSE: bool = true; - fn step(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>) { + fn step

(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>) { let mut k: Vec> = vec![SVector::::zeros(); Self::STAGES]; let mut next_y = ode.y.clone(); let mut err = SVector::::zeros(); let mut rcont5 = SVector::::zeros(); // Do the first of the summations - k[0] = (ode.f)(ode.t, ode.y); + k[0] = (ode.f)(ode.t, ode.y, &ode.params); next_y += k[0] * Self::B[0] * h; err += k[0] * (Self::B[0] - Self::B[Self::STAGES]) * h; let rcont1 = ode.y; @@ -113,7 +113,7 @@ where for j in 0..i { y_term += k[j] * Self::A[( i * (i - 1) ) / 2 + j]; } - k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h); + k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params); // Use that and bis to calculate the y and error terms next_y += k[i] * h * Self::B[i]; diff --git a/src/integrator/mod.rs b/src/integrator/mod.rs index 4af2ab7..8d3e1c7 100644 --- a/src/integrator/mod.rs +++ b/src/integrator/mod.rs @@ -13,7 +13,7 @@ pub trait Integrator { const DENSE: bool; /// Returns a new y value, then possibly an error value, and possibly a dense output /// coefficient set - fn step(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>); + fn step

(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>); fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec>, t: f64) -> SVector; } @@ -28,10 +28,11 @@ mod tests { #[test] fn test_dopri() { - fn derivative(_t: f64, y: Vector3) -> Vector3 { y } + type Params = (); + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); - let mut ode = ODE::new(&derivative, 0.0, 4.0, y0); + let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); let dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64); diff --git a/src/lib.rs b/src/lib.rs index 51fba1f..d218abc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,24 +9,57 @@ pub mod problem; #[cfg(test)] mod tests { - use nalgebra::Vector6; + use nalgebra::{Vector3, Vector6}; use approx::assert_relative_eq; use crate::integrator::dormand_prince::DormandPrince45; use crate::controller::PIController; + use crate::callback::{Callback, stop}; use crate::ode::ODE; use crate::problem::Problem; use std::f64::consts::PI; #[test] - fn test_orbit() { + fn test_readme() { + // Define the system (parameters, derivative, and initial state) + type Params = (f64, bool); + let params = (34.0, true); + fn derivative(t: f64, y: Vector3, p: &Params) -> Vector3 { + if p.1 { -y } else { y * t } + } + + let y0 = Vector3::new(1.0, 1.0, 1.0); + + // Set up the problem (ODE, Integrator, Controller, and Callbacks) + let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); + let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); + let controller = PIController::default(); + + let value_too_high = Callback { + event: &|_: f64, y: Vector3, _: &Params| { 10.0 - y[0] }, + effect: &stop, + }; + + // Solve the problem + let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); + let solution = problem.solve(); + + // Can interpolate solutions to whatever you want + let _interpolated_answer = solution.interpolate(8.2); + } + + #[test] + fn test_orbit() { // Calculate one period let a = 6.7781363e6_f64; - let period = 2.0 * PI * (a.powi(3)/3.98600441500000e14).sqrt(); + let mu = 3.98600441500000e14; + let period = 2.0 * PI * (a.powi(3)/mu).sqrt(); // Set up the system - fn derivative(_t: f64, state: Vector6) -> Vector6 { - let acc = -(3.98600441500000e14 * state.fixed_rows::<3>(0)) / (state.fixed_rows::<3>(0).norm().powi(3)); + type Params = (f64,); + let params = (mu,); + fn derivative(_t: f64, state: Vector6, p: &Params) -> Vector6 { + let acc = -(p.0 * state.fixed_rows::<3>(0)) / (state.fixed_rows::<3>(0).norm().powi(3)); Vector6::new(state[3], state[4], state[5], acc[0], acc[1], acc[2]) } let y0 = Vector6::new( @@ -39,7 +72,7 @@ mod tests { ); // Integrate - let ode = ODE::new(&derivative, 0.0, 10.0*period, y0); + let ode = ODE::new(&derivative, 0.0, 10.0*period, y0, params); let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64); let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); diff --git a/src/ode.rs b/src/ode.rs index 3cd8c52..a5b366f 100644 --- a/src/ode.rs +++ b/src/ode.rs @@ -3,22 +3,30 @@ use nalgebra::SVector; /// The basic ODE object that will be passed around. The type (T) and the size (D) will be /// determined upon creation of the object #[derive(Clone, Copy)] -pub struct ODE<'a, const D: usize> { - pub f: &'a dyn Fn(f64, SVector) -> SVector, +pub struct ODE<'a, const D: usize, P> { + pub f: &'a dyn Fn(f64, SVector, &P) -> SVector, pub y: SVector, pub t: f64, + pub params: P, pub t0: f64, pub t_end: f64, pub h: f64, pub finished: bool, } -impl<'a, const D: usize> ODE<'a,D> { - pub fn new(f: &'a (dyn Fn(f64, SVector) -> SVector), t0: f64, t_end: f64, y0: SVector) -> Self { +impl<'a, const D: usize, P> ODE<'a,D, P> { + pub fn new( + f: &'a (dyn Fn(f64, SVector, &P) -> SVector), + t0: f64, + t_end: f64, + y0: SVector, + params: P, + ) -> Self { Self { f: f, y: y0, t: t0, + params: params, t0: t0, t_end: t_end, h: 0.001, @@ -35,12 +43,32 @@ mod tests { #[test] fn test_ode_creation() { - fn derivative(_t: f64, y: Vector3) -> Vector3 { -y } + type Params = (); + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { -y } let y0 = Vector3::new(1.0, 0.0, 0.0); - let ode = ODE::new(&derivative, 0.0, 10.0, y0); + let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); - assert!((ode.f)(0.0, y0) == Vector3::new(-1.0, 0.0, 0.0)); + assert!((ode.f)(0.0, y0, &()) == Vector3::new(-1.0, 0.0, 0.0)); + assert!(ode.y == Vector3::new(1.0, 0.0, 0.0)); + assert!(ode.t == 0.0); + assert!(!ode.finished); + assert!(ode.t_end == 10.0); + } + + #[test] + fn test_ode_with_params() { + type Params = (f64, bool); + let params = (34.0, true); + + fn derivative(t: f64, y: Vector3, p: &Params) -> Vector3 { + if p.1 { -y } else { y * t } + } + + let y0 = Vector3::new(1.0, 0.0, 0.0); + let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); + + assert!((ode.f)(0.0, y0, ¶ms) == Vector3::new(-1.0, 0.0, 0.0)); assert!(ode.y == Vector3::new(1.0, 0.0, 0.0)); assert!(ode.t == 0.0); assert!(!ode.finished); diff --git a/src/problem.rs b/src/problem.rs index d978fc6..0ca7635 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -7,21 +7,21 @@ use super::integrator::Integrator; use super::callback::Callback; #[derive(Clone)] -pub struct Problem<'a, const D: usize, S> +pub struct Problem<'a, const D: usize, S, P> where S: Integrator, { - ode: ODE<'a, D>, + ode: ODE<'a, D, P>, integrator: S, controller: PIController, - callbacks: Vec>, + callbacks: Vec>, } -impl<'a, const D: usize, S> Problem<'a,D,S> +impl<'a, const D: usize, S, P> Problem<'a,D,S,P> where S: Integrator + Copy, { - pub fn new(ode: ODE<'a,D>, integrator: S, controller: PIController) -> Self { + pub fn new(ode: ODE<'a,D,P>, integrator: S, controller: PIController) -> Self { Problem { ode: ode, integrator: integrator, @@ -55,17 +55,17 @@ where if self.callbacks.len() > 0 { // Check for events occurring for callback in &self.callbacks { - println!("{}", (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y)); - if (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y) < 0.0 { + println!("{}", (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params)); + if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params) < 0.0 { // If the event crossed zero, then find the root let f = |test_t| { let test_y = self.integrator.step(&self.ode, test_t).0; - (callback.event)(self.ode.t + test_t, test_y) + (callback.event)(self.ode.t + test_t, test_y, &self.ode.params) }; let root = find_root_regula_falsi(0.0, step, &f, &mut 1e-12).unwrap(); step = root; (new_y, _, dense_option) = self.integrator.step(&self.ode, step); - self.ode = (callback.effect)(self.ode); + (callback.effect)(&mut self.ode); } } } @@ -84,7 +84,7 @@ where } } - pub fn with_callback(mut self, callback: Callback<'a, D>) -> Self { + pub fn with_callback(mut self, callback: Callback<'a, D, P>) -> Self { self.callbacks.push(callback); Self { ode: self.ode, @@ -140,10 +140,11 @@ mod tests { #[test] fn test_problem() { - fn derivative(_t: f64, y: Vector3) -> Vector3 { y } + type Params = (); + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); - let ode = ODE::new(&derivative, 0.0, 1.0, y0); + let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let controller = PIController::default(); @@ -157,15 +158,16 @@ mod tests { #[test] fn test_with_callback() { - fn derivative(_t: f64, y: Vector3) -> Vector3 { y } + type Params = (); + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); - let ode = ODE::new(&derivative, 0.0, 5.0, y0); + let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let controller = PIController::default(); let value_too_high = Callback { - event: &|_: f64, y: SVector| { 10.0 - y[0] }, + event: &|_: f64, y: SVector, _: &Params| { 10.0 - y[0] }, effect: &stop, }; @@ -178,10 +180,11 @@ mod tests { #[test] fn test_with_interpolation() { - fn derivative(_t: f64, y: Vector3) -> Vector3 { y } + type Params = (); + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); - let ode = ODE::new(&derivative, 0.0, 10.0, y0); + let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); let controller = PIController::default();