diff --git a/benches/orbit.rs b/benches/orbit.rs index 74e5201..89aaacf 100644 --- a/benches/orbit.rs +++ b/benches/orbit.rs @@ -24,7 +24,7 @@ fn bench_orbit(c: &mut Criterion) { // Integrate let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params); - let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64); + let dp45 = DormandPrince45::new(); let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); c.bench_function("bench_orbit", |b| { diff --git a/benches/simple_1d.rs b/benches/simple_1d.rs index b247d3d..d556f91 100644 --- a/benches/simple_1d.rs +++ b/benches/simple_1d.rs @@ -15,7 +15,7 @@ fn bench_simple_1d(c: &mut Criterion) { // Set up the problem (ODE, Integrator, Controller, and Callbacks) let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); - let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64); + let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6); let controller = PIController::default(); c.bench_function("bench_simple_1d", |b| { @@ -39,7 +39,7 @@ fn bench_interpolation_1d(c: &mut Criterion) { // Set up the problem (ODE, Integrator, Controller, and Callbacks) let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); - let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64); + let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6); let controller = PIController::default(); c.bench_function("bench_interpolation_1d", |b| { diff --git a/src/controller.rs b/src/controller.rs index 6258c03..ce1443d 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -1,5 +1,31 @@ +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum TryStep { + Accepted(f64, f64), + NotYetAccepted(f64), +} + +impl TryStep { + pub fn extract(&self) -> f64 { + match self { + TryStep::Accepted(h, _) => *h, + TryStep::NotYetAccepted(h) => *h, + } + } + + pub fn is_accepted(&self) -> bool { + matches!(self, TryStep::Accepted(_, _)) + } + + pub fn reset(&mut self) -> Result { + match self { + TryStep::Accepted(_, h) => Ok(TryStep::NotYetAccepted(*h)), + TryStep::NotYetAccepted(_) => Err("Cannot reset a NotYetAccepted TryStep"), + } + } +} + pub trait Controller { - fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64); + fn determine_step(&mut self, h: f64, err: f64) -> TryStep; } #[derive(Debug, Clone, Copy)] @@ -11,32 +37,30 @@ pub struct PIController { pub factor_old: f64, pub h_max: f64, pub safety_factor: f64, - pub old_h: f64, + pub next_step_guess: TryStep, } impl Controller for PIController { /// Determines if the previously run step size and error were valid or not. Either way, it also /// returns what the next step size should be - fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) { + fn determine_step(&mut self, prev_step: f64, err: f64) -> TryStep { let factor_11 = err.powf(self.alpha); let factor = self.factor_c2.max( self.factor_c1 .min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor), ); - let mut h_new = h / factor; if err <= 1.0 { - // Accept the stepsize + let mut h = prev_step / factor; + // Accept the stepsize and provide what the next step size should be self.factor_old = err.max(1.0e-4); - if h_new.abs() > self.h_max { - // If the step is too big - h_new = self.h_max.copysign(h_new); + if h.abs() > self.h_max { + // If the step goes past the maximum allowed, though, we shrink it + h = self.h_max.copysign(h); } - (true, h_new) - // (true, h_new) + TryStep::Accepted(prev_step, h) } else { - // Reject the stepsize and propose a smaller one - h_new = h / (self.factor_c1.min(factor_11 / self.safety_factor)); - (false, h_new) + // Reject the stepsize and propose a smaller one for the current step + TryStep::NotYetAccepted(prev_step / (self.factor_c1.min(factor_11 / self.safety_factor))) } } } @@ -59,7 +83,7 @@ impl PIController { factor_old: 1.0e-4, h_max: h_max.abs(), safety_factor, - old_h: initial_h, + next_step_guess: TryStep::NotYetAccepted(initial_h), } } } @@ -85,6 +109,6 @@ mod tests { assert!(controller.factor_old == 1.0e-4); assert!(controller.h_max == 10.0); assert!(controller.safety_factor == 0.9); - assert!(controller.old_h == 1e-4); + assert!(controller.next_step_guess == TryStep::NotYetAccepted(1e-4)); } } diff --git a/src/integrator/dormand_prince.rs b/src/integrator/dormand_prince.rs index 709fab6..7e3004a 100644 --- a/src/integrator/dormand_prince.rs +++ b/src/integrator/dormand_prince.rs @@ -13,7 +13,7 @@ pub trait DormandPrinceIntegrator<'a> { #[derive(Debug, Clone, Copy)] pub struct DormandPrince45 { - a_tol: f64, + a_tol: SVector, r_tol: f64, } @@ -21,8 +21,17 @@ impl DormandPrince45 where DormandPrince45: Integrator, { - pub fn new(a_tol: f64, r_tol: f64) -> Self { - Self { a_tol, r_tol } + pub fn new() -> Self { + Self { a_tol: SVector::::from_element(1e-8), r_tol: 1e-8 } + } + pub fn a_tol(&mut self, a_tol: f64) -> Self { + Self { a_tol: SVector::::from_element(a_tol), r_tol: self.r_tol } + } + pub fn a_tol_full(&mut self, a_tol: SVector::) -> Self { + Self { a_tol, r_tol: self.r_tol } + } + pub fn r_tol(&mut self, r_tol: f64) -> Self { + Self { a_tol: self.a_tol, r_tol } } } @@ -119,7 +128,7 @@ where let rcont2 = next_y - ode.y; let rcont3 = h * k[0] - rcont2; let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; - let tol = SVector::::repeat(self.a_tol) + ode.y * self.r_tol; + let tol = self.a_tol + ode.y * self.r_tol; let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5]; (next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) } diff --git a/src/integrator/mod.rs b/src/integrator/mod.rs index 00c74e5..a7a6445 100644 --- a/src/integrator/mod.rs +++ b/src/integrator/mod.rs @@ -44,7 +44,7 @@ mod tests { let y0 = Vector3::new(1.0, 1.0, 1.0); let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64); + let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-4); // Test that y'(t) = y(t) solves to y(t) = e^t for rkf54 // and also that the error seems reasonable diff --git a/src/lib.rs b/src/lib.rs index 608ef17..20e33bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,7 @@ mod tests { // Set up the problem (ODE, Integrator, Controller, and Callbacks) let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); + let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6); let controller = PIController::default(); let value_too_high = Callback { @@ -68,7 +68,7 @@ mod tests { // Set up the problem (ODE, Integrator, Controller, and Callbacks) let ode = ODE::new(&derivative, 2.0, 3.0, y0, params); - let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64); + let dp45 = DormandPrince45::new(); let controller = PIController::default(); // Solve the problem @@ -105,11 +105,9 @@ mod tests { // Integrate let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64); + let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-12); let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); - let mut problem = Problem::new(ode, dp45, controller); - let solution = problem.solve(); assert_relative_eq!( diff --git a/src/problem.rs b/src/problem.rs index f96c37d..b3f2cb6 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -2,7 +2,7 @@ use nalgebra::SVector; use roots::{find_root_brent, SimpleConvergency}; use super::callback::Callback; -use super::controller::{Controller, PIController}; +use super::controller::{Controller, PIController, TryStep}; use super::integrator::Integrator; use super::ode::ODE; @@ -30,41 +30,57 @@ where } } pub fn solve(&mut self) -> Solution { - let mut convergency = SimpleConvergency { eps: 1e-12, max_iter: 1000 }; + let mut convergency = SimpleConvergency { + eps: 1e-12, + max_iter: 1000, + }; let mut times: Vec = vec![self.ode.t]; let mut states: Vec> = vec![self.ode.y]; let mut dense_coefficients: Vec>> = Vec::new(); - let mut step: f64 = self.controller.old_h; - let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0); while self.ode.t < self.ode.t_end { - let mut dense_option: Option>> = None; - if S::ADAPTIVE { + if self.ode.t + self.controller.next_step_guess.extract() > self.ode.t_end { + // If the next step would go past the end, then just set it to the end + self.controller.next_step_guess = TryStep::NotYetAccepted( + self.ode.t_end - self.ode.t, + ); + } + let (mut new_y, mut curr_step, mut dense_option) = if S::ADAPTIVE { + // First, we try stepping with the "next step guess" to get the error + let (mut trial_y, mut err_option, mut dense_option) = + self.integrator.step(&self.ode, self.controller.next_step_guess.extract()); let mut err = err_option.unwrap(); - let mut accepted: bool = false; - while !accepted { - // Try a step and if that isn't acceptable, then change the step until it is - (accepted, step) = >::determine_step( + // Then we determine whether we need to reduce the step size or not + // If successful, we get the next step guess + let initial_guess = self.controller.next_step_guess.extract(); + let mut next_step_guess = >::determine_step( + &mut self.controller, + initial_guess, + err, + ); + while !next_step_guess.is_accepted() { + // If that step isn't acceptable, then change the step until it is + (trial_y, err_option, dense_option) = + self.integrator.step(&self.ode, next_step_guess.extract()); + next_step_guess = >::determine_step( &mut self.controller, - step, + next_step_guess.extract(), err, ); - (new_y, err_option, dense_option) = self.integrator.step(&self.ode, step); err = err_option.unwrap(); } - self.controller.old_h = step; - self.controller.h_max = self - .controller - .h_max - .min(self.ode.t_end - self.ode.t - step); + // So at this point we can safely assume we have an accepted step + self.controller.next_step_guess = next_step_guess.reset().unwrap(); + (trial_y, next_step_guess.extract(), dense_option) } else { // If fixed time step just step forward one step - (new_y, _, dense_option) = self.integrator.step(&self.ode, step); - } + let (trial_y, _, dense_option) = self.integrator.step(&self.ode, self.controller.next_step_guess.extract()); + (trial_y, self.controller.next_step_guess.extract(), dense_option) + }; if !self.callbacks.is_empty() { // Check for events occurring for callback in &self.callbacks { if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) - * (callback.event)(self.ode.t + step, new_y, &self.ode.params) + * (callback.event)(self.ode.t + curr_step, new_y, &self.ode.params) < 0.0 { // If the event crossed zero, then find the root @@ -72,15 +88,15 @@ where let test_y = self.integrator.step(&self.ode, test_t).0; (callback.event)(self.ode.t + test_t, test_y, &self.ode.params) }; - let root = find_root_brent(0.0, step, &f, &mut convergency).unwrap(); - step = root; - (new_y, _, dense_option) = self.integrator.step(&self.ode, step); + let root = find_root_brent(0.0, curr_step, &f, &mut convergency).unwrap(); + curr_step = root; + (new_y, _, dense_option) = self.integrator.step(&self.ode, curr_step); (callback.effect)(&mut self.ode); } } } self.ode.y = new_y; - self.ode.t += step; + self.ode.t += curr_step; times.push(self.ode.t); states.push(self.ode.y); // TODO: Implement third order interpolation for non-dense algorithms @@ -142,7 +158,7 @@ where let t_end = times[end_index]; self.integrator .interpolate(t_start, t_end, &self.dense[end_index - 1], t) - }, + } } } } @@ -165,7 +181,7 @@ mod tests { let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); + let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5); let controller = PIController::default(); let mut problem = Problem::new(ode, dp45, controller); @@ -189,7 +205,7 @@ mod tests { let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); + let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5); let controller = PIController::default(); let value_too_high = Callback { @@ -200,7 +216,11 @@ mod tests { let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); let solution = problem.solve(); - assert_relative_eq!(solution.states.last().unwrap()[0], 10.0, max_relative = 1e-11); + assert_relative_eq!( + solution.states.last().unwrap()[0], + 10.0, + max_relative = 1e-11 + ); } #[test] @@ -212,7 +232,7 @@ mod tests { let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); - let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); + let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6); let controller = PIController::default(); let mut problem = Problem::new(ode, dp45, controller);