Formatting changes

This commit is contained in:
Connor Johnstone
2023-10-16 14:16:44 -06:00
parent 16e958c5db
commit 9d82f92c07
7 changed files with 194 additions and 103 deletions

View File

@@ -27,7 +27,7 @@ mod tests {
fn test_basic_callbacks() { fn test_basic_callbacks() {
type Params = (); type Params = ();
let _value_too_high = Callback { let _value_too_high = Callback {
event: &|_: f64, y: SVector<f64,3>, _p: &Params| { 10.0 - y[0] }, event: &|_: f64, y: SVector<f64, 3>, _p: &Params| 10.0 - y[0],
effect: &stop, effect: &stop,
}; };
} }

View File

@@ -19,7 +19,10 @@ impl<const D:usize> Controller<D> for PIController {
/// returns what the next step size should be /// returns what the next step size should be
fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) { fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) {
let factor_11 = err.powf(self.alpha); let factor_11 = err.powf(self.alpha);
let factor = self.factor_c2.max(self.factor_c1.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor)); let factor = self.factor_c2.max(
self.factor_c1
.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor),
);
let mut h_new = h / factor; let mut h_new = h / factor;
if err <= 1.0 { if err <= 1.0 {
// Accept the stepsize // Accept the stepsize
@@ -39,7 +42,15 @@ impl<const D:usize> Controller<D> for PIController {
} }
impl PIController { impl PIController {
pub fn new(alpha:f64, beta:f64, max_factor: f64, min_factor: f64, h_max: f64, safety_factor: f64, initial_h: f64) -> Self { pub fn new(
alpha: f64,
beta: f64,
max_factor: f64,
min_factor: f64,
h_max: f64,
safety_factor: f64,
initial_h: f64,
) -> Self {
Self { Self {
alpha: alpha, alpha: alpha,
beta: beta, beta: beta,

View File

@@ -17,12 +17,12 @@ pub struct DormandPrince45<const D: usize> {
r_tol: f64, r_tol: f64,
} }
impl<const D: usize> DormandPrince45<D> where DormandPrince45<D>: Integrator<D> { impl<const D: usize> DormandPrince45<D>
where
DormandPrince45<D>: Integrator<D>,
{
pub fn new(a_tol: f64, r_tol: f64) -> Self { pub fn new(a_tol: f64, r_tol: f64) -> Self {
Self { Self { a_tol, r_tol }
a_tol,
r_tol,
}
} }
} }
@@ -66,15 +66,7 @@ impl<'a, const D: usize> DormandPrinceIntegrator<'a> for DormandPrince45<D> {
187.0 / 2_100.0, 187.0 / 2_100.0,
1.0 / 40.0, 1.0 / 40.0,
]; ];
const C: &'a [f64] = &[ const C: &'a [f64] = &[0.0, 1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0];
0.0,
1.0 / 5.0,
3.0 / 10.0,
4.0 / 5.0,
8.0 / 9.0,
1.0,
1.0,
];
const D: &'a [f64] = &[ const D: &'a [f64] = &[
-12715105075.0 / 11282082432.0, -12715105075.0 / 11282082432.0,
0.0, 0.0,
@@ -95,8 +87,12 @@ where
const ADAPTIVE: bool = true; const ADAPTIVE: bool = true;
const DENSE: bool = true; const DENSE: bool = true;
fn step<P>(&self, ode: &ODE<D,P>, h: f64) -> (SVector<f64,D>, Option<f64>, Option<Vec<SVector<f64, D>>>) { fn step<P>(
let mut k: Vec<SVector::<f64,D>> = vec![SVector::<f64,D>::zeros(); Self::STAGES]; &self,
ode: &ODE<D, P>,
h: f64,
) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) {
let mut k: Vec<SVector<f64, D>> = vec![SVector::<f64, D>::zeros(); Self::STAGES];
let mut next_y = ode.y.clone(); let mut next_y = ode.y.clone();
let mut err = SVector::<f64, D>::zeros(); let mut err = SVector::<f64, D>::zeros();
let mut rcont5 = SVector::<f64, D>::zeros(); let mut rcont5 = SVector::<f64, D>::zeros();
@@ -124,10 +120,16 @@ where
let rcont3 = h * k[0] - rcont2; let rcont3 = h * k[0] - rcont2;
let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3;
let tol = SVector::<f64, D>::repeat(self.a_tol) + ode.y * self.r_tol; let tol = SVector::<f64, D>::repeat(self.a_tol) + ode.y * self.r_tol;
let rcont = vec![ rcont1, rcont2, rcont3, rcont4, rcont5, ]; let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5];
(next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) (next_y, Some((err.component_div(&tol)).norm()), Some(rcont))
} }
fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec<SVector<f64,D>>, t: f64) -> SVector<f64,D> { fn interpolate(
&self,
t_start: f64,
t_end: f64,
dense: &Vec<SVector<f64, D>>,
t: f64,
) -> SVector<f64, D> {
let s = (t - t_start) / (t_end - t_start); let s = (t - t_start) / (t_end - t_start);
let s1 = 1.0 - s; let s1 = 1.0 - s;
dense[0] + (dense[1] + (dense[2] + (dense[3] + dense[4] * s1) * s) * s1) * s dense[0] + (dense[1] + (dense[2] + (dense[3] + dense[4] * s1) * s) * s1) * s

View File

@@ -13,23 +13,33 @@ pub trait Integrator<const D: usize> {
const DENSE: bool; const DENSE: bool;
/// Returns a new y value, then possibly an error value, and possibly a dense output /// Returns a new y value, then possibly an error value, and possibly a dense output
/// coefficient set /// coefficient set
fn step<P>(&self, ode: &ODE<D,P>, h: f64) -> (SVector<f64,D>, Option<f64>, Option<Vec<SVector<f64, D>>>); fn step<P>(
fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec<SVector<f64,D>>, t: f64) -> SVector<f64,D>; &self,
ode: &ODE<D, P>,
h: f64,
) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>);
fn interpolate(
&self,
t_start: f64,
t_end: f64,
dense: &Vec<SVector<f64, D>>,
t: f64,
) -> SVector<f64, D>;
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use super::dormand_prince::*; use super::dormand_prince::*;
use nalgebra::Vector3; use super::*;
use approx::assert_relative_eq; use approx::assert_relative_eq;
use nalgebra::Vector3;
#[test] #[test]
fn test_dopri() { fn test_dopri() {
type Params = (); type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y } fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ());

View File

@@ -1,24 +1,24 @@
#![allow(dead_code)] #![allow(dead_code)]
pub mod ode;
pub mod integrator;
pub mod controller;
pub mod callback; pub mod callback;
pub mod controller;
pub mod integrator;
pub mod ode;
pub mod problem; pub mod problem;
pub mod prelude { pub mod prelude {
pub use super::ode::ODE; pub use super::callback::{stop, Callback};
pub use super::integrator::dormand_prince::DormandPrince45;
pub use super::controller::PIController; pub use super::controller::PIController;
pub use super::callback::{Callback, stop}; pub use super::integrator::dormand_prince::DormandPrince45;
pub use super::ode::ODE;
pub use super::problem::{Problem, Solution}; pub use super::problem::{Problem, Solution};
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use nalgebra::{Vector2, Vector6};
use approx::assert_relative_eq;
use crate::prelude::*; use crate::prelude::*;
use approx::assert_relative_eq;
use nalgebra::{Vector2, Vector6};
use std::f64::consts::PI; use std::f64::consts::PI;
#[test] #[test]
@@ -42,7 +42,7 @@ mod tests {
let controller = PIController::default(); let controller = PIController::default();
let value_too_high = Callback { let value_too_high = Callback {
event: &|t: f64, _y: Vector2<f64>, _p: &Params| { 5.0 - t }, event: &|t: f64, _y: Vector2<f64>, _p: &Params| 5.0 - t,
effect: &stop, effect: &stop,
}; };
@@ -86,12 +86,40 @@ mod tests {
let solution = problem.solve(); let solution = problem.solve();
assert_relative_eq!(solution.times[solution.states.len()-1], 10.0 * period, max_relative=1e-12); assert_relative_eq!(
assert_relative_eq!(solution.states[solution.states.len()-1][0], y0[0], max_relative=1e-9); solution.times[solution.states.len() - 1],
assert_relative_eq!(solution.states[solution.states.len()-1][1], y0[1], max_relative=1e-9); 10.0 * period,
assert_relative_eq!(solution.states[solution.states.len()-1][2], y0[2], max_relative=1e-9); max_relative = 1e-12
assert_relative_eq!(solution.states[solution.states.len()-1][3], y0[3], max_relative=1e-9); );
assert_relative_eq!(solution.states[solution.states.len()-1][4], y0[4], max_relative=1e-9); assert_relative_eq!(
assert_relative_eq!(solution.states[solution.states.len()-1][5], y0[5], max_relative=1e-9); solution.states[solution.states.len() - 1][0],
y0[0],
max_relative = 1e-9
);
assert_relative_eq!(
solution.states[solution.states.len() - 1][1],
y0[1],
max_relative = 1e-9
);
assert_relative_eq!(
solution.states[solution.states.len() - 1][2],
y0[2],
max_relative = 1e-9
);
assert_relative_eq!(
solution.states[solution.states.len() - 1][3],
y0[3],
max_relative = 1e-9
);
assert_relative_eq!(
solution.states[solution.states.len() - 1][4],
y0[4],
max_relative = 1e-9
);
assert_relative_eq!(
solution.states[solution.states.len() - 1][5],
y0[5],
max_relative = 1e-9
);
} }
} }

View File

@@ -23,19 +23,18 @@ impl<'a, const D: usize, P> ODE<'a,D, P> {
params: P, params: P,
) -> Self { ) -> Self {
Self { Self {
f: f, f,
y: y0, y: y0,
t: t0, t: t0,
params: params, params,
t0: t0, t0,
t_end: t_end, t_end,
h: 0.001, h: 0.001,
finished: false, finished: false,
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@@ -44,7 +43,9 @@ mod tests {
#[test] #[test]
fn test_ode_creation() { fn test_ode_creation() {
type Params = (); type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { -y } fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
-y
}
let y0 = Vector3::new(1.0, 0.0, 0.0); let y0 = Vector3::new(1.0, 0.0, 0.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
@@ -62,7 +63,11 @@ mod tests {
let params = (34.0, true); let params = (34.0, true);
fn derivative(t: f64, y: Vector3<f64>, p: &Params) -> Vector3<f64> { fn derivative(t: f64, y: Vector3<f64>, p: &Params) -> Vector3<f64> {
if p.1 { -y } else { y * t } if p.1 {
-y
} else {
y * t
}
} }
let y0 = Vector3::new(1.0, 0.0, 0.0); let y0 = Vector3::new(1.0, 0.0, 0.0);

View File

@@ -1,10 +1,10 @@
use nalgebra::SVector; use nalgebra::SVector;
use roots::find_root_regula_falsi; use roots::find_root_regula_falsi;
use super::ode::ODE; use super::callback::Callback;
use super::controller::{Controller, PIController}; use super::controller::{Controller, PIController};
use super::integrator::Integrator; use super::integrator::Integrator;
use super::callback::Callback; use super::ode::ODE;
#[derive(Clone)] #[derive(Clone)]
pub struct Problem<'a, const D: usize, S, P> pub struct Problem<'a, const D: usize, S, P>
@@ -30,9 +30,9 @@ where
} }
} }
pub fn solve(&mut self) -> Solution<S, D> { pub fn solve(&mut self) -> Solution<S, D> {
let mut times: Vec::<f64> = vec![self.ode.t]; let mut times: Vec<f64> = vec![self.ode.t];
let mut states: Vec::<SVector<f64,D>> = vec![self.ode.y]; let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
let mut dense_coefficients: Vec::<Vec<SVector<f64,D>>> = Vec::new(); let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();
let mut step: f64 = self.controller.old_h; let mut step: f64 = self.controller.old_h;
let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0); let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0);
while self.ode.t < self.ode.t_end { while self.ode.t < self.ode.t_end {
@@ -42,12 +42,19 @@ where
let mut accepted: bool = false; let mut accepted: bool = false;
while !accepted { while !accepted {
// Try a step and if that isn't acceptable, then change the step until it is // Try a step and if that isn't acceptable, then change the step until it is
(accepted, step) = <PIController as Controller<D>>::determine_step(&mut self.controller, step, err); (accepted, step) = <PIController as Controller<D>>::determine_step(
&mut self.controller,
step,
err,
);
(new_y, err_option, dense_option) = self.integrator.step(&self.ode, step); (new_y, err_option, dense_option) = self.integrator.step(&self.ode, step);
err = err_option.unwrap(); err = err_option.unwrap();
} }
self.controller.old_h = step; self.controller.old_h = step;
self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step); self.controller.h_max = self
.controller
.h_max
.min(self.ode.t_end - self.ode.t - step);
} else { } else {
// If fixed time step just step forward one step // If fixed time step just step forward one step
(new_y, _, dense_option) = self.integrator.step(&self.ode, step); (new_y, _, dense_option) = self.integrator.step(&self.ode, step);
@@ -55,7 +62,10 @@ where
if self.callbacks.len() > 0 { if self.callbacks.len() > 0 {
// Check for events occurring // Check for events occurring
for callback in &self.callbacks { for callback in &self.callbacks {
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params) < 0.0 { if (callback.event)(self.ode.t, self.ode.y, &self.ode.params)
* (callback.event)(self.ode.t + step, new_y, &self.ode.params)
< 0.0
{
// If the event crossed zero, then find the root // If the event crossed zero, then find the root
let f = |test_t| { let f = |test_t| {
let test_y = self.integrator.step(&self.ode, test_t).0; let test_y = self.integrator.step(&self.ode, test_t).0;
@@ -77,8 +87,8 @@ where
} }
Solution { Solution {
integrator: self.integrator, integrator: self.integrator,
times: times, times,
states: states, states,
dense: dense_coefficients, dense: dense_coefficients,
} }
} }
@@ -94,14 +104,20 @@ where
} }
} }
pub struct Solution<S, const D: usize> where S: Integrator<D> { pub struct Solution<S, const D: usize>
where
S: Integrator<D>,
{
pub integrator: S, pub integrator: S,
pub times: Vec<f64>, pub times: Vec<f64>,
pub states: Vec<SVector<f64, D>>, pub states: Vec<SVector<f64, D>>,
pub dense: Vec::<Vec<SVector<f64,D>>>, pub dense: Vec<Vec<SVector<f64, D>>>,
} }
impl<S, const D: usize> Solution<S,D> where S: Integrator<D> { impl<S, const D: usize> Solution<S, D>
where
S: Integrator<D>,
{
pub fn interpolate(&self, t: f64) -> SVector<f64, D> { pub fn interpolate(&self, t: f64) -> SVector<f64, D> {
// First check that the t is within bounds // First check that the t is within bounds
let last = self.times.last().unwrap(); let last = self.times.last().unwrap();
@@ -109,8 +125,12 @@ impl<S, const D: usize> Solution<S,D> where S: Integrator<D> {
// TODO: Improve these errors // TODO: Improve these errors
let mut times = self.times.clone(); let mut times = self.times.clone();
if *first > *last { times.reverse(); } if *first > *last {
if t < *first || t > *last { panic!(); } times.reverse();
}
if t < *first || t > *last {
panic!();
}
// Then find the two t values closest to the desired t // Then find the two t values closest to the desired t
let mut end_index: usize = 0; let mut end_index: usize = 0;
@@ -124,23 +144,26 @@ impl<S, const D: usize> Solution<S,D> where S: Integrator<D> {
// Then send that to the integrator // Then send that to the integrator
let t_start = times[end_index - 1]; let t_start = times[end_index - 1];
let t_end = times[end_index]; let t_end = times[end_index];
self.integrator.interpolate(t_start, t_end, &self.dense[end_index - 1], t) self.integrator
.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use nalgebra::Vector3;
use approx::assert_relative_eq;
use crate::integrator::dormand_prince::DormandPrince45;
use crate::controller::PIController;
use crate::callback::stop; use crate::callback::stop;
use crate::controller::PIController;
use crate::integrator::dormand_prince::DormandPrince45;
use approx::assert_relative_eq;
use nalgebra::Vector3;
#[test] #[test]
fn test_problem() { fn test_problem() {
type Params = (); type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y } fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
@@ -150,7 +173,11 @@ mod tests {
let mut problem = Problem::new(ode, dp45, controller); let mut problem = Problem::new(ode, dp45, controller);
let solution = problem.solve(); let solution = problem.solve();
solution.times.iter().zip(solution.states.iter()).for_each(|(time, state)| { solution
.times
.iter()
.zip(solution.states.iter())
.for_each(|(time, state)| {
assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2); assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2);
}) })
} }
@@ -158,7 +185,9 @@ mod tests {
#[test] #[test]
fn test_with_callback() { fn test_with_callback() {
type Params = (); type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y } fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
@@ -166,7 +195,7 @@ mod tests {
let controller = PIController::default(); let controller = PIController::default();
let value_too_high = Callback { let value_too_high = Callback {
event: &|_: f64, y: SVector<f64,3>, _: &Params| { 10.0 - y[0] }, event: &|_: f64, y: SVector<f64, 3>, _: &Params| 10.0 - y[0],
effect: &stop, effect: &stop,
}; };
@@ -179,7 +208,9 @@ mod tests {
#[test] #[test]
fn test_with_interpolation() { fn test_with_interpolation() {
type Params = (); type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y } fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0); let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
@@ -189,6 +220,10 @@ mod tests {
let mut problem = Problem::new(ode, dp45, controller); let mut problem = Problem::new(ode, dp45, controller);
let solution = problem.solve(); let solution = problem.solve();
assert_relative_eq!(solution.interpolate(8.8)[0], 8.8_f64.exp(), max_relative=1e-6); assert_relative_eq!(
solution.interpolate(8.8)[0],
8.8_f64.exp(),
max_relative = 1e-6
);
} }
} }