From 43ec9eb0acc0fdd12d4beef2fccfaa36ea141ebb Mon Sep 17 00:00:00 2001 From: Connor Johnstone Date: Tue, 14 Mar 2023 16:55:06 -0600 Subject: [PATCH] Finished dopri5, interpolation, and callbacks --- Cargo.toml | 1 + src/callback.rs | 34 +++++++ src/controller.rs | 2 - src/integrator/dormand_prince.rs | 48 +++++++--- src/integrator/mod.rs | 16 ++-- src/integrator/rosenbrock.rs | 7 +- src/lib.rs | 38 ++++---- src/ode.rs | 8 -- src/problem.rs | 147 ++++++++++++++++++++++++++----- 9 files changed, 228 insertions(+), 73 deletions(-) create mode 100644 src/callback.rs diff --git a/Cargo.toml b/Cargo.toml index 4ca52a6..7172348 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" serde = { version = "1.0", features = ["derive"] } nalgebra = { version = "0.32", features = ["serde-serialize"] } num-traits = "0.2.15" +roots = "0.0.8" [dev-dependencies] approx = "0.5" diff --git a/src/callback.rs b/src/callback.rs new file mode 100644 index 0000000..58591db --- /dev/null +++ b/src/callback.rs @@ -0,0 +1,34 @@ +use nalgebra::SVector; +use super::ode::ODE; + +/// A function that takes in a time and a state and outputs a single float value +/// +/// The integration solver will check this function for zero crossings +#[derive(Clone, Copy)] +pub struct Callback<'a, const D: usize> { + /// The function to check for zero crossings + pub event: &'a dyn Fn(f64, SVector) -> f64, + + /// The function to change the ODE + pub effect: &'a dyn Fn(ODE) -> ODE, +} + +/// A convenience function for stopping the integration +pub fn stop(ode: ODE) -> ODE { + let mut new_ode = ode.clone(); + new_ode.t_end = new_ode.t; + new_ode +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_callbacks() { + let _value_too_high = Callback { + event: &|_: f64, y: SVector| { 10.0 - y[0] }, + effect: &stop, + }; + } +} diff --git a/src/controller.rs b/src/controller.rs index aee00c8..3cf5fe1 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -21,8 +21,6 @@ impl Controller for PIController { let factor_11 = err.powf(self.alpha); let factor = self.factor_c2.max(self.factor_c1.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor)); let mut h_new = h / factor; - // let mut h_new = 0.9 * h * err.powf(-1.0 / 5.0); - println!("err: {}\th_new: {}", err, h_new); if err <= 1.0 { // Accept the stepsize self.factor_old = err.max(1.0e-4); diff --git a/src/integrator/dormand_prince.rs b/src/integrator/dormand_prince.rs index 7b8c43a..a1cc201 100644 --- a/src/integrator/dormand_prince.rs +++ b/src/integrator/dormand_prince.rs @@ -8,10 +8,11 @@ pub trait DormandPrinceIntegrator { const A: &'static [f64]; const B: &'static [f64]; const C: &'static [f64]; + const D: &'static [f64]; } +#[derive(Debug, Clone, Copy)] pub struct DormandPrince45 { - k: Vec>, a_tol: f64, r_tol: f64, } @@ -19,7 +20,6 @@ pub struct DormandPrince45 { impl DormandPrince45 where DormandPrince45: Integrator { pub fn new(a_tol: f64, r_tol: f64) -> Self { Self { - k: vec![SVector::::zeros(); Self::STAGES], a_tol: a_tol, r_tol: r_tol, } @@ -75,35 +75,61 @@ impl DormandPrinceIntegrator for DormandPrince45 { 1.0, 1.0, ]; + const D: &'static [f64] = &[ + -12715105075.0 / 11282082432.0, + 0.0, + 87487479700.0 / 32700410799.0, + -10690763975.0 / 1880347072.0, + 701980252875.0 / 199316789632.0, + -1453857185.0 / 822651844.0, + 69997945.0 / 29380423.0, + ]; } impl Integrator for DormandPrince45 where DormandPrince45: DormandPrinceIntegrator, { + const ORDER: usize = 5; const STAGES: usize = 7; + const ADAPTIVE: bool = true; + const DENSE: bool = true; - fn step(&mut self, ode: &ODE, h: f64) -> (SVector, Option) { + fn step(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>) { + let mut k: Vec> = vec![SVector::::zeros(); Self::STAGES]; let mut next_y = ode.y.clone(); let mut err = SVector::::zeros(); + let mut rcont5 = SVector::::zeros(); // Do the first of the summations - self.k[0] = (ode.f)(ode.t, ode.y); - next_y += self.k[0] * Self::B[0] * h; - err += self.k[0] * (Self::B[0] - Self::B[Self::STAGES]) * h; + k[0] = (ode.f)(ode.t, ode.y); + next_y += k[0] * Self::B[0] * h; + err += k[0] * (Self::B[0] - Self::B[Self::STAGES]) * h; + let rcont1 = ode.y; + rcont5 += k[0] * h * Self::D[0]; // Then the rest for i in 1..Self::STAGES { // Compute the ks let mut y_term = SVector::::zeros(); for j in 0..i { - y_term += self.k[i-j-1] * Self::A[( i * (i - 1) ) / 2 + j]; + y_term += k[j] * Self::A[( i * (i - 1) ) / 2 + j]; } - self.k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h); + k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h); // Use that and bis to calculate the y and error terms - next_y += self.k[i] * h * Self::B[i]; - err += self.k[i] * (Self::B[i] - Self::B[i + Self::STAGES]) * h; + next_y += k[i] * h * Self::B[i]; + err += k[i] * (Self::B[i] - Self::B[i + Self::STAGES]) * h; + rcont5 += k[i] * h * Self::D[i]; } + let rcont2 = next_y - ode.y; + let rcont3 = h * k[0] - rcont2; + let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; let tol = SVector::::repeat(self.a_tol) + ode.y * self.r_tol; - (next_y, Some((err.component_div(&tol)).norm())) + let rcont = vec![ rcont1, rcont2, rcont3, rcont4, rcont5, ]; + (next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) + } + fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec>, t: f64) -> SVector { + let s = (t - t_start)/(t_end - t_start); + let s1 = 1.0 - s; + dense[0] + (dense[1] + (dense[2] + (dense[3] + dense[4] * s1) * s) * s1) * s } } diff --git a/src/integrator/mod.rs b/src/integrator/mod.rs index 6b4d2ae..4af2ab7 100644 --- a/src/integrator/mod.rs +++ b/src/integrator/mod.rs @@ -3,12 +3,18 @@ use nalgebra::SVector; use super::ode::ODE; pub mod dormand_prince; -pub mod rosenbrock; +// pub mod rosenbrock; /// Integrator Trait pub trait Integrator { + const ORDER: usize; const STAGES: usize; - fn step(&mut self, ode: &ODE, h: f64) -> (SVector, Option); + const ADAPTIVE: bool; + const DENSE: bool; + /// Returns a new y value, then possibly an error value, and possibly a dense output + /// coefficient set + fn step(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>); + fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec>, t: f64) -> SVector; } @@ -27,13 +33,13 @@ mod tests { let y0 = Vector3::new(1.0, 1.0, 1.0); let mut ode = ODE::new(&derivative, 0.0, 4.0, y0); - let mut dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64); + let dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64); // Test that y'(t) = y(t) solves to y(t) = e^t for rkf54 // and also that the error seems reasonable - let step = 0.0005; + let step = 0.001; while ode.t < ode.t_end { - let (new_y, err) = dp45.step(&ode, step); + let (new_y, err, _) = dp45.step(&ode, step); ode.y = new_y; ode.t += step; assert_relative_eq!(ode.y[0], ode.t.exp(), max_relative=0.01); diff --git a/src/integrator/rosenbrock.rs b/src/integrator/rosenbrock.rs index 78d8857..bdf320c 100644 --- a/src/integrator/rosenbrock.rs +++ b/src/integrator/rosenbrock.rs @@ -91,11 +91,12 @@ where Rodas4: RosenbrockIntegrator, { const STAGES: usize = 6; + const ADAPTIVE: bool = true; // TODO: Finish this - fn step(&mut self, ode: &ODE, h: f64) -> (SVector, Option) { - let mut next_y = ode.y.clone(); - let mut err = SVector::::zeros(); + fn step(&self, ode: &ODE, _h: f64) -> (SVector, Option) { + let next_y = ode.y.clone(); + let err = SVector::::zeros(); (next_y, Some(err.norm())) } } diff --git a/src/lib.rs b/src/lib.rs index f830607..51fba1f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ pub mod ode; pub mod integrator; pub mod controller; -// pub mod callback; +pub mod callback; pub mod problem; @@ -23,7 +23,6 @@ mod tests { // Calculate one period let a = 6.7781363e6_f64; let period = 2.0 * PI * (a.powi(3)/3.98600441500000e14).sqrt(); - println!("{}", period); // Set up the system fn derivative(_t: f64, state: Vector6) -> Vector6 { @@ -31,32 +30,29 @@ mod tests { Vector6::new(state[3], state[4], state[5], acc[0], acc[1], acc[2]) } let y0 = Vector6::new( - 4.26387250e+06, - 5.14619397e+06, - 1.13102192e+06, - -5.92345023e+03, - 4.49679662e+03, - 1.87038714e+03, + 4.263868426884883e6, + 5.146189057155391e6, + 1.1310208421331816e6, + -5923.454461876975, + 4496.802639690076, + 1870.3893008991558, ); // Integrate - let ode = ODE::new(&derivative, 0.0, period, y0); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-8_f64); - let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 10.0, 0.9, 1e-4); + let ode = ODE::new(&derivative, 0.0, 10.0*period, y0); + let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64); + let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); let mut problem = Problem::new(ode, dp45, controller); let solution = problem.solve(); - println!("{}", solution.times.len()); - // panic!(); - // TODO: Something still isn't right with these tolerances I think... - assert_relative_eq!(solution.times[solution.states.len()-1], period, max_relative=1e-7); - assert_relative_eq!(solution.states[solution.states.len()-1][0], y0[0], max_relative=1e-3); - assert_relative_eq!(solution.states[solution.states.len()-1][1], y0[1], max_relative=1e-3); - assert_relative_eq!(solution.states[solution.states.len()-1][2], y0[2], max_relative=1e-3); - assert_relative_eq!(solution.states[solution.states.len()-1][3], y0[3], max_relative=1e-3); - assert_relative_eq!(solution.states[solution.states.len()-1][4], y0[4], max_relative=1e-3); - assert_relative_eq!(solution.states[solution.states.len()-1][5], y0[5], max_relative=1e-3); + assert_relative_eq!(solution.times[solution.states.len()-1], 10.0 * period, max_relative=1e-12); + assert_relative_eq!(solution.states[solution.states.len()-1][0], y0[0], max_relative=1e-9); + assert_relative_eq!(solution.states[solution.states.len()-1][1], y0[1], max_relative=1e-9); + assert_relative_eq!(solution.states[solution.states.len()-1][2], y0[2], max_relative=1e-9); + assert_relative_eq!(solution.states[solution.states.len()-1][3], y0[3], max_relative=1e-9); + assert_relative_eq!(solution.states[solution.states.len()-1][4], y0[4], max_relative=1e-9); + assert_relative_eq!(solution.states[solution.states.len()-1][5], y0[5], max_relative=1e-9); } } diff --git a/src/ode.rs b/src/ode.rs index be8765d..3cd8c52 100644 --- a/src/ode.rs +++ b/src/ode.rs @@ -1,13 +1,5 @@ use nalgebra::SVector; -/// A System trait. -/// -/// The user will have to define their own system. They are free to add params to their system -/// definition and use those in the derivative function -pub trait SystemTrait { - fn derivative(&self, t: T, y: SVector) -> SVector; -} - /// The basic ODE object that will be passed around. The type (T) and the size (D) will be /// determined upon creation of the object #[derive(Clone, Copy)] diff --git a/src/problem.rs b/src/problem.rs index 2a75211..69dd142 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -1,9 +1,12 @@ use nalgebra::SVector; +use roots::find_root_regula_falsi; use super::ode::ODE; use super::controller::{Controller, PIController}; use super::integrator::Integrator; +use super::callback::Callback; +#[derive(Clone)] pub struct Problem<'a, const D: usize, S> where S: Integrator, @@ -11,56 +14,119 @@ where ode: ODE<'a, D>, integrator: S, controller: PIController, + callbacks: Vec>, } impl<'a, const D: usize, S> Problem<'a,D,S> where - S: Integrator, + S: Integrator + Copy, { pub fn new(ode: ODE<'a,D>, integrator: S, controller: PIController) -> Self { Problem { ode: ode, integrator: integrator, controller: controller, + callbacks: Vec::new(), } } - pub fn solve(&mut self) -> Solution { - let mut times: Vec:: = Vec::new(); - let mut states: Vec::> = Vec::new(); + pub fn solve(&mut self) -> Solution { + let mut times: Vec:: = vec![self.ode.t]; + let mut states: Vec::> = vec![self.ode.y]; + let mut dense_coefficients: Vec::>> = Vec::new(); let mut step: f64 = self.controller.old_h; - times.push(self.ode.t); - states.push(self.ode.y); - let (mut new_y, mut err_option) = self.integrator.step(&self.ode, 0.0); + let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0); while self.ode.t < self.ode.t_end { - match err_option { - Some(mut err) => { - // Adaptive Step Size - let mut accepted: bool = false; - while !accepted { - (accepted, step) = >::determine_step(&mut self.controller, step, err); - (new_y, err_option) = self.integrator.step(&self.ode, step); - err = err_option.unwrap(); + let mut dense_option: Option>> = None; + if S::ADAPTIVE { + let mut err = err_option.unwrap(); + let mut accepted: bool = false; + while !accepted { + // Try a step and if that isn't acceptable, then change the step until it is + (accepted, step) = >::determine_step(&mut self.controller, step, err); + (new_y, err_option, dense_option) = self.integrator.step(&self.ode, step); + err = err_option.unwrap(); + } + self.controller.old_h = step; + self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step); + } else { + // If fixed time step just step forward one step + (new_y, _, dense_option) = self.integrator.step(&self.ode, step); + } + if self.callbacks.len() > 0 { + // Check for events occurring + for callback in &self.callbacks { + println!("{}", (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y)); + if (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y) < 0.0 { + // If the event crossed zero, then find the root + let f = |test_t| { + let test_y = self.integrator.step(&self.ode, test_t).0; + (callback.event)(self.ode.t + test_t, test_y) + }; + let root = find_root_regula_falsi(0.0, step, &f, &mut 1e-12).unwrap(); + step = root; + (new_y, _, dense_option) = self.integrator.step(&self.ode, step); + self.ode = (callback.effect)(self.ode); } - self.controller.old_h = step; - self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step); - }, - None => {}, - }; + } + } self.ode.y = new_y; self.ode.t += step; times.push(self.ode.t); states.push(self.ode.y); + // TODO: Implement third order interpolation for non-dense algorithms + dense_coefficients.push(dense_option.unwrap()); } Solution { + integrator: self.integrator, times: times, states: states, + dense: dense_coefficients, + } + } + + pub fn with_callback(mut self, callback: Callback<'a, D>) -> Self { + self.callbacks.push(callback); + Self { + ode: self.ode, + integrator: self.integrator, + controller: self.controller, + callbacks: self.callbacks, } } } -pub struct Solution { +pub struct Solution where S: Integrator { + pub integrator: S, pub times: Vec, pub states: Vec>, + pub dense: Vec::>>, +} + +impl Solution where S: Integrator { + pub fn interpolate(&self, t: f64) -> SVector { + // First check that the t is within bounds + let last = self.times.last().unwrap(); + let first = self.times.first().unwrap(); + + // TODO: Improve these errors + let mut times = self.times.clone(); + if *first > *last { times.reverse(); } + if t < *first || t > *last { panic!(); } + + // Then find the two t values closest to the desired t + let mut end_index: usize = 0; + for (i, time) in self.times.iter().enumerate() { + if time > &t { + end_index = i; + break; + } + } + + // Then send that to the integrator + let t_start = times[end_index - 1]; + let t_end = times[end_index]; + self.integrator.interpolate(t_start, t_end, &self.dense[end_index - 1], t) + } } #[cfg(test)] @@ -70,6 +136,7 @@ mod tests { use approx::assert_relative_eq; use crate::integrator::dormand_prince::DormandPrince45; use crate::controller::PIController; + use crate::callback::stop; #[test] fn test_problem() { @@ -83,10 +150,44 @@ mod tests { let mut problem = Problem::new(ode, dp45, controller); let solution = problem.solve(); - // println!("{}", solution.times.len()); - // panic!(); solution.times.iter().zip(solution.states.iter()).for_each(|(time, state)| { assert_relative_eq!(state[0], time.exp(), max_relative=1e-2); }) } + + #[test] + fn test_with_callback() { + fn derivative(_t: f64, y: Vector3) -> Vector3 { y } + let y0 = Vector3::new(1.0, 1.0, 1.0); + + let ode = ODE::new(&derivative, 0.0, 5.0, y0); + let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); + let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8); + + let value_too_high = Callback { + event: &|_: f64, y: SVector| { 10.0 - y[0] }, + effect: &stop, + }; + + let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); + let solution = problem.solve(); + + println!("{}", solution.states.last().unwrap()[0]); + assert!(solution.states.last().unwrap()[0] == 10.0); + } + + #[test] + fn test_with_interpolation() { + fn derivative(_t: f64, y: Vector3) -> Vector3 { y } + let y0 = Vector3::new(1.0, 1.0, 1.0); + + let ode = ODE::new(&derivative, 0.0, 10.0, y0); + let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); + let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8); + + let mut problem = Problem::new(ode, dp45, controller); + let solution = problem.solve(); + + assert_relative_eq!(solution.interpolate(8.8)[0], 8.8_f64.exp(), max_relative=1e-6); + } }