Updated the way that steps are handled

This commit is contained in:
Connor Johnstone
2025-08-12 15:54:23 -04:00
parent 9075dac669
commit 2659d78582
7 changed files with 108 additions and 57 deletions

View File

@@ -24,7 +24,7 @@ fn bench_orbit(c: &mut Criterion) {
// Integrate // Integrate
let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params); let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params);
let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64); let dp45 = DormandPrince45::new();
let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01);
c.bench_function("bench_orbit", |b| { c.bench_function("bench_orbit", |b| {

View File

@@ -15,7 +15,7 @@ fn bench_simple_1d(c: &mut Criterion) {
// Set up the problem (ODE, Integrator, Controller, and Callbacks) // Set up the problem (ODE, Integrator, Controller, and Callbacks)
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64); let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6);
let controller = PIController::default(); let controller = PIController::default();
c.bench_function("bench_simple_1d", |b| { c.bench_function("bench_simple_1d", |b| {
@@ -39,7 +39,7 @@ fn bench_interpolation_1d(c: &mut Criterion) {
// Set up the problem (ODE, Integrator, Controller, and Callbacks) // Set up the problem (ODE, Integrator, Controller, and Callbacks)
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64); let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6);
let controller = PIController::default(); let controller = PIController::default();
c.bench_function("bench_interpolation_1d", |b| { c.bench_function("bench_interpolation_1d", |b| {

View File

@@ -1,5 +1,31 @@
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TryStep {
Accepted(f64, f64),
NotYetAccepted(f64),
}
impl TryStep {
pub fn extract(&self) -> f64 {
match self {
TryStep::Accepted(h, _) => *h,
TryStep::NotYetAccepted(h) => *h,
}
}
pub fn is_accepted(&self) -> bool {
matches!(self, TryStep::Accepted(_, _))
}
pub fn reset(&mut self) -> Result<TryStep, &str> {
match self {
TryStep::Accepted(_, h) => Ok(TryStep::NotYetAccepted(*h)),
TryStep::NotYetAccepted(_) => Err("Cannot reset a NotYetAccepted TryStep"),
}
}
}
pub trait Controller<const D: usize> { pub trait Controller<const D: usize> {
fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64); fn determine_step(&mut self, h: f64, err: f64) -> TryStep;
} }
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@@ -11,32 +37,30 @@ pub struct PIController {
pub factor_old: f64, pub factor_old: f64,
pub h_max: f64, pub h_max: f64,
pub safety_factor: f64, pub safety_factor: f64,
pub old_h: f64, pub next_step_guess: TryStep,
} }
impl<const D: usize> Controller<D> for PIController { impl<const D: usize> Controller<D> for PIController {
/// Determines if the previously run step size and error were valid or not. Either way, it also /// Determines if the previously run step size and error were valid or not. Either way, it also
/// returns what the next step size should be /// returns what the next step size should be
fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) { fn determine_step(&mut self, prev_step: f64, err: f64) -> TryStep {
let factor_11 = err.powf(self.alpha); let factor_11 = err.powf(self.alpha);
let factor = self.factor_c2.max( let factor = self.factor_c2.max(
self.factor_c1 self.factor_c1
.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor), .min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor),
); );
let mut h_new = h / factor;
if err <= 1.0 { if err <= 1.0 {
// Accept the stepsize let mut h = prev_step / factor;
// Accept the stepsize and provide what the next step size should be
self.factor_old = err.max(1.0e-4); self.factor_old = err.max(1.0e-4);
if h_new.abs() > self.h_max { if h.abs() > self.h_max {
// If the step is too big // If the step goes past the maximum allowed, though, we shrink it
h_new = self.h_max.copysign(h_new); h = self.h_max.copysign(h);
} }
(true, h_new) TryStep::Accepted(prev_step, h)
// (true, h_new)
} else { } else {
// Reject the stepsize and propose a smaller one // Reject the stepsize and propose a smaller one for the current step
h_new = h / (self.factor_c1.min(factor_11 / self.safety_factor)); TryStep::NotYetAccepted(prev_step / (self.factor_c1.min(factor_11 / self.safety_factor)))
(false, h_new)
} }
} }
} }
@@ -59,7 +83,7 @@ impl PIController {
factor_old: 1.0e-4, factor_old: 1.0e-4,
h_max: h_max.abs(), h_max: h_max.abs(),
safety_factor, safety_factor,
old_h: initial_h, next_step_guess: TryStep::NotYetAccepted(initial_h),
} }
} }
} }
@@ -85,6 +109,6 @@ mod tests {
assert!(controller.factor_old == 1.0e-4); assert!(controller.factor_old == 1.0e-4);
assert!(controller.h_max == 10.0); assert!(controller.h_max == 10.0);
assert!(controller.safety_factor == 0.9); assert!(controller.safety_factor == 0.9);
assert!(controller.old_h == 1e-4); assert!(controller.next_step_guess == TryStep::NotYetAccepted(1e-4));
} }
} }

View File

@@ -13,7 +13,7 @@ pub trait DormandPrinceIntegrator<'a> {
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct DormandPrince45<const D: usize> { pub struct DormandPrince45<const D: usize> {
a_tol: f64, a_tol: SVector<f64,D>,
r_tol: f64, r_tol: f64,
} }
@@ -21,8 +21,17 @@ impl<const D: usize> DormandPrince45<D>
where where
DormandPrince45<D>: Integrator<D>, DormandPrince45<D>: Integrator<D>,
{ {
pub fn new(a_tol: f64, r_tol: f64) -> Self { pub fn new() -> Self {
Self { a_tol, r_tol } Self { a_tol: SVector::<f64,D>::from_element(1e-8), r_tol: 1e-8 }
}
pub fn a_tol(&mut self, a_tol: f64) -> Self {
Self { a_tol: SVector::<f64,D>::from_element(a_tol), r_tol: self.r_tol }
}
pub fn a_tol_full(&mut self, a_tol: SVector::<f64,D>) -> Self {
Self { a_tol, r_tol: self.r_tol }
}
pub fn r_tol(&mut self, r_tol: f64) -> Self {
Self { a_tol: self.a_tol, r_tol }
} }
} }
@@ -119,7 +128,7 @@ where
let rcont2 = next_y - ode.y; let rcont2 = next_y - ode.y;
let rcont3 = h * k[0] - rcont2; let rcont3 = h * k[0] - rcont2;
let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3;
let tol = SVector::<f64, D>::repeat(self.a_tol) + ode.y * self.r_tol; let tol = self.a_tol + ode.y * self.r_tol;
let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5]; let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5];
(next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) (next_y, Some((err.component_div(&tol)).norm()), Some(rcont))
} }

View File

@@ -44,7 +44,7 @@ mod tests {
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64); let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-4);
// Test that y'(t) = y(t) solves to y(t) = e^t for rkf54 // Test that y'(t) = y(t) solves to y(t) = e^t for rkf54
// and also that the error seems reasonable // and also that the error seems reasonable

View File

@@ -38,7 +38,7 @@ mod tests {
// Set up the problem (ODE, Integrator, Controller, and Callbacks) // Set up the problem (ODE, Integrator, Controller, and Callbacks)
let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); let ode = ODE::new(&derivative, 0.0, 6.3, y0, params);
let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
let controller = PIController::default(); let controller = PIController::default();
let value_too_high = Callback { let value_too_high = Callback {
@@ -68,7 +68,7 @@ mod tests {
// Set up the problem (ODE, Integrator, Controller, and Callbacks) // Set up the problem (ODE, Integrator, Controller, and Callbacks)
let ode = ODE::new(&derivative, 2.0, 3.0, y0, params); let ode = ODE::new(&derivative, 2.0, 3.0, y0, params);
let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64); let dp45 = DormandPrince45::new();
let controller = PIController::default(); let controller = PIController::default();
// Solve the problem // Solve the problem
@@ -105,11 +105,9 @@ mod tests {
// Integrate // Integrate
let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params); let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params);
let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64); let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-12);
let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01);
let mut problem = Problem::new(ode, dp45, controller); let mut problem = Problem::new(ode, dp45, controller);
let solution = problem.solve(); let solution = problem.solve();
assert_relative_eq!( assert_relative_eq!(

View File

@@ -2,7 +2,7 @@ use nalgebra::SVector;
use roots::{find_root_brent, SimpleConvergency}; use roots::{find_root_brent, SimpleConvergency};
use super::callback::Callback; use super::callback::Callback;
use super::controller::{Controller, PIController}; use super::controller::{Controller, PIController, TryStep};
use super::integrator::Integrator; use super::integrator::Integrator;
use super::ode::ODE; use super::ode::ODE;
@@ -30,41 +30,57 @@ where
} }
} }
pub fn solve(&mut self) -> Solution<S, D> { pub fn solve(&mut self) -> Solution<S, D> {
let mut convergency = SimpleConvergency { eps: 1e-12, max_iter: 1000 }; let mut convergency = SimpleConvergency {
eps: 1e-12,
max_iter: 1000,
};
let mut times: Vec<f64> = vec![self.ode.t]; let mut times: Vec<f64> = vec![self.ode.t];
let mut states: Vec<SVector<f64, D>> = vec![self.ode.y]; let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new(); let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();
let mut step: f64 = self.controller.old_h;
let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0);
while self.ode.t < self.ode.t_end { while self.ode.t < self.ode.t_end {
let mut dense_option: Option<Vec<SVector<f64, D>>> = None; if self.ode.t + self.controller.next_step_guess.extract() > self.ode.t_end {
if S::ADAPTIVE { // If the next step would go past the end, then just set it to the end
self.controller.next_step_guess = TryStep::NotYetAccepted(
self.ode.t_end - self.ode.t,
);
}
let (mut new_y, mut curr_step, mut dense_option) = if S::ADAPTIVE {
// First, we try stepping with the "next step guess" to get the error
let (mut trial_y, mut err_option, mut dense_option) =
self.integrator.step(&self.ode, self.controller.next_step_guess.extract());
let mut err = err_option.unwrap(); let mut err = err_option.unwrap();
let mut accepted: bool = false; // Then we determine whether we need to reduce the step size or not
while !accepted { // If successful, we get the next step guess
// Try a step and if that isn't acceptable, then change the step until it is let initial_guess = self.controller.next_step_guess.extract();
(accepted, step) = <PIController as Controller<D>>::determine_step( let mut next_step_guess = <PIController as Controller<D>>::determine_step(
&mut self.controller,
initial_guess,
err,
);
while !next_step_guess.is_accepted() {
// If that step isn't acceptable, then change the step until it is
(trial_y, err_option, dense_option) =
self.integrator.step(&self.ode, next_step_guess.extract());
next_step_guess = <PIController as Controller<D>>::determine_step(
&mut self.controller, &mut self.controller,
step, next_step_guess.extract(),
err, err,
); );
(new_y, err_option, dense_option) = self.integrator.step(&self.ode, step);
err = err_option.unwrap(); err = err_option.unwrap();
} }
self.controller.old_h = step; // So at this point we can safely assume we have an accepted step
self.controller.h_max = self self.controller.next_step_guess = next_step_guess.reset().unwrap();
.controller (trial_y, next_step_guess.extract(), dense_option)
.h_max
.min(self.ode.t_end - self.ode.t - step);
} else { } else {
// If fixed time step just step forward one step // If fixed time step just step forward one step
(new_y, _, dense_option) = self.integrator.step(&self.ode, step); let (trial_y, _, dense_option) = self.integrator.step(&self.ode, self.controller.next_step_guess.extract());
} (trial_y, self.controller.next_step_guess.extract(), dense_option)
};
if !self.callbacks.is_empty() { if !self.callbacks.is_empty() {
// Check for events occurring // Check for events occurring
for callback in &self.callbacks { for callback in &self.callbacks {
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) if (callback.event)(self.ode.t, self.ode.y, &self.ode.params)
* (callback.event)(self.ode.t + step, new_y, &self.ode.params) * (callback.event)(self.ode.t + curr_step, new_y, &self.ode.params)
< 0.0 < 0.0
{ {
// If the event crossed zero, then find the root // If the event crossed zero, then find the root
@@ -72,15 +88,15 @@ where
let test_y = self.integrator.step(&self.ode, test_t).0; let test_y = self.integrator.step(&self.ode, test_t).0;
(callback.event)(self.ode.t + test_t, test_y, &self.ode.params) (callback.event)(self.ode.t + test_t, test_y, &self.ode.params)
}; };
let root = find_root_brent(0.0, step, &f, &mut convergency).unwrap(); let root = find_root_brent(0.0, curr_step, &f, &mut convergency).unwrap();
step = root; curr_step = root;
(new_y, _, dense_option) = self.integrator.step(&self.ode, step); (new_y, _, dense_option) = self.integrator.step(&self.ode, curr_step);
(callback.effect)(&mut self.ode); (callback.effect)(&mut self.ode);
} }
} }
} }
self.ode.y = new_y; self.ode.y = new_y;
self.ode.t += step; self.ode.t += curr_step;
times.push(self.ode.t); times.push(self.ode.t);
states.push(self.ode.y); states.push(self.ode.y);
// TODO: Implement third order interpolation for non-dense algorithms // TODO: Implement third order interpolation for non-dense algorithms
@@ -142,7 +158,7 @@ where
let t_end = times[end_index]; let t_end = times[end_index];
self.integrator self.integrator
.interpolate(t_start, t_end, &self.dense[end_index - 1], t) .interpolate(t_start, t_end, &self.dense[end_index - 1], t)
}, }
} }
} }
} }
@@ -165,7 +181,7 @@ mod tests {
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
let controller = PIController::default(); let controller = PIController::default();
let mut problem = Problem::new(ode, dp45, controller); let mut problem = Problem::new(ode, dp45, controller);
@@ -189,7 +205,7 @@ mod tests {
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
let controller = PIController::default(); let controller = PIController::default();
let value_too_high = Callback { let value_too_high = Callback {
@@ -200,7 +216,11 @@ mod tests {
let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high);
let solution = problem.solve(); let solution = problem.solve();
assert_relative_eq!(solution.states.last().unwrap()[0], 10.0, max_relative = 1e-11); assert_relative_eq!(
solution.states.last().unwrap()[0],
10.0,
max_relative = 1e-11
);
} }
#[test] #[test]
@@ -212,7 +232,7 @@ mod tests {
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
let controller = PIController::default(); let controller = PIController::default();
let mut problem = Problem::new(ode, dp45, controller); let mut problem = Problem::new(ode, dp45, controller);