From 9d82f92c070f53cf6da40e3974a9f13582294d2a Mon Sep 17 00:00:00 2001 From: Connor Johnstone Date: Mon, 16 Oct 2023 14:16:44 -0600 Subject: [PATCH] Formatting changes --- src/callback.rs | 8 +-- src/controller.rs | 21 +++++-- src/integrator/dormand_prince.rs | 46 ++++++++------- src/integrator/mod.rs | 26 ++++++--- src/lib.rs | 68 +++++++++++++++------- src/ode.rs | 29 ++++++---- src/problem.rs | 99 +++++++++++++++++++++----------- 7 files changed, 194 insertions(+), 103 deletions(-) diff --git a/src/callback.rs b/src/callback.rs index b559bac..d6080ae 100644 --- a/src/callback.rs +++ b/src/callback.rs @@ -8,14 +8,14 @@ use super::ode::ODE; #[derive(Clone, Copy)] pub struct Callback<'a, const D: usize, P> { /// The function to check for zero crossings - pub event: &'a dyn Fn(f64, SVector, &P) -> f64, + pub event: &'a dyn Fn(f64, SVector, &P) -> f64, /// The function to change the ODE - pub effect: &'a dyn Fn(&mut ODE) -> (), + pub effect: &'a dyn Fn(&mut ODE) -> (), } /// A convenience function for stopping the integration -pub fn stop<'a, const D: usize, P>(ode: &mut ODE) -> () { +pub fn stop<'a, const D: usize, P>(ode: &mut ODE) -> () { ode.t_end = ode.t; } @@ -27,7 +27,7 @@ mod tests { fn test_basic_callbacks() { type Params = (); let _value_too_high = Callback { - event: &|_: f64, y: SVector, _p: &Params| { 10.0 - y[0] }, + event: &|_: f64, y: SVector, _p: &Params| 10.0 - y[0], effect: &stop, }; } diff --git a/src/controller.rs b/src/controller.rs index a0a8a42..5ce917b 100644 --- a/src/controller.rs +++ b/src/controller.rs @@ -14,12 +14,15 @@ pub struct PIController { pub old_h: f64, } -impl Controller for PIController { +impl Controller for PIController { /// Determines if the previously run step size and error were valid or not. Either way, it also /// returns what the next step size should be fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) { let factor_11 = err.powf(self.alpha); - let factor = self.factor_c2.max(self.factor_c1.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor)); + let factor = self.factor_c2.max( + self.factor_c1 + .min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor), + ); let mut h_new = h / factor; if err <= 1.0 { // Accept the stepsize @@ -39,7 +42,15 @@ impl Controller for PIController { } impl PIController { - pub fn new(alpha:f64, beta:f64, max_factor: f64, min_factor: f64, h_max: f64, safety_factor: f64, initial_h: f64) -> Self { + pub fn new( + alpha: f64, + beta: f64, + max_factor: f64, + min_factor: f64, + h_max: f64, + safety_factor: f64, + initial_h: f64, + ) -> Self { Self { alpha: alpha, beta: beta, @@ -66,8 +77,8 @@ mod tests { assert!(controller.alpha == 0.17); assert!(controller.beta == 0.04); - assert!(controller.factor_c1 == 1.0/0.2); - assert!(controller.factor_c2 == 1.0/10.0); + assert!(controller.factor_c1 == 1.0 / 0.2); + assert!(controller.factor_c2 == 1.0 / 10.0); assert!(controller.factor_old == 1.0e-4); assert!(controller.h_max == 10.0); assert!(controller.safety_factor == 0.9); diff --git a/src/integrator/dormand_prince.rs b/src/integrator/dormand_prince.rs index 6953000..06c1ddf 100644 --- a/src/integrator/dormand_prince.rs +++ b/src/integrator/dormand_prince.rs @@ -17,12 +17,12 @@ pub struct DormandPrince45 { r_tol: f64, } -impl DormandPrince45 where DormandPrince45: Integrator { +impl DormandPrince45 +where + DormandPrince45: Integrator, +{ pub fn new(a_tol: f64, r_tol: f64) -> Self { - Self { - a_tol, - r_tol, - } + Self { a_tol, r_tol } } } @@ -66,15 +66,7 @@ impl<'a, const D: usize> DormandPrinceIntegrator<'a> for DormandPrince45 { 187.0 / 2_100.0, 1.0 / 40.0, ]; - const C: &'a [f64] = &[ - 0.0, - 1.0 / 5.0, - 3.0 / 10.0, - 4.0 / 5.0, - 8.0 / 9.0, - 1.0, - 1.0, - ]; + const C: &'a [f64] = &[0.0, 1.0 / 5.0, 3.0 / 10.0, 4.0 / 5.0, 8.0 / 9.0, 1.0, 1.0]; const D: &'a [f64] = &[ -12715105075.0 / 11282082432.0, 0.0, @@ -95,8 +87,12 @@ where const ADAPTIVE: bool = true; const DENSE: bool = true; - fn step

(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>) { - let mut k: Vec> = vec![SVector::::zeros(); Self::STAGES]; + fn step

( + &self, + ode: &ODE, + h: f64, + ) -> (SVector, Option, Option>>) { + let mut k: Vec> = vec![SVector::::zeros(); Self::STAGES]; let mut next_y = ode.y.clone(); let mut err = SVector::::zeros(); let mut rcont5 = SVector::::zeros(); @@ -109,9 +105,9 @@ where // Then the rest for i in 1..Self::STAGES { // Compute the ks - let mut y_term = SVector::::zeros(); + let mut y_term = SVector::::zeros(); for j in 0..i { - y_term += k[j] * Self::A[( i * (i - 1) ) / 2 + j]; + y_term += k[j] * Self::A[(i * (i - 1)) / 2 + j]; } k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params); @@ -123,12 +119,18 @@ where let rcont2 = next_y - ode.y; let rcont3 = h * k[0] - rcont2; let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; - let tol = SVector::::repeat(self.a_tol) + ode.y * self.r_tol; - let rcont = vec![ rcont1, rcont2, rcont3, rcont4, rcont5, ]; + let tol = SVector::::repeat(self.a_tol) + ode.y * self.r_tol; + let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5]; (next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) } - fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec>, t: f64) -> SVector { - let s = (t - t_start)/(t_end - t_start); + fn interpolate( + &self, + t_start: f64, + t_end: f64, + dense: &Vec>, + t: f64, + ) -> SVector { + let s = (t - t_start) / (t_end - t_start); let s1 = 1.0 - s; dense[0] + (dense[1] + (dense[2] + (dense[3] + dense[4] * s1) * s) * s1) * s } diff --git a/src/integrator/mod.rs b/src/integrator/mod.rs index 8d3e1c7..c625b1b 100644 --- a/src/integrator/mod.rs +++ b/src/integrator/mod.rs @@ -13,23 +13,33 @@ pub trait Integrator { const DENSE: bool; /// Returns a new y value, then possibly an error value, and possibly a dense output /// coefficient set - fn step

(&self, ode: &ODE, h: f64) -> (SVector, Option, Option>>); - fn interpolate(&self, t_start: f64, t_end: f64, dense: &Vec>, t: f64) -> SVector; + fn step

( + &self, + ode: &ODE, + h: f64, + ) -> (SVector, Option, Option>>); + fn interpolate( + &self, + t_start: f64, + t_end: f64, + dense: &Vec>, + t: f64, + ) -> SVector; } - #[cfg(test)] mod tests { - use super::*; use super::dormand_prince::*; - use nalgebra::Vector3; + use super::*; use approx::assert_relative_eq; - + use nalgebra::Vector3; #[test] fn test_dopri() { type Params = (); - fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { + y + } let y0 = Vector3::new(1.0, 1.0, 1.0); let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); @@ -43,7 +53,7 @@ mod tests { let (new_y, err, _) = dp45.step(&ode, step); ode.y = new_y; ode.t += step; - assert_relative_eq!(ode.y[0], ode.t.exp(), max_relative=0.01); + assert_relative_eq!(ode.y[0], ode.t.exp(), max_relative = 0.01); assert!(err.unwrap() < 1.0); } } diff --git a/src/lib.rs b/src/lib.rs index 8ee27ec..da32822 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,24 +1,24 @@ #![allow(dead_code)] -pub mod ode; -pub mod integrator; -pub mod controller; pub mod callback; +pub mod controller; +pub mod integrator; +pub mod ode; pub mod problem; pub mod prelude { - pub use super::ode::ODE; - pub use super::integrator::dormand_prince::DormandPrince45; + pub use super::callback::{stop, Callback}; pub use super::controller::PIController; - pub use super::callback::{Callback, stop}; + pub use super::integrator::dormand_prince::DormandPrince45; + pub use super::ode::ODE; pub use super::problem::{Problem, Solution}; } #[cfg(test)] mod tests { - use nalgebra::{Vector2, Vector6}; - use approx::assert_relative_eq; use crate::prelude::*; + use approx::assert_relative_eq; + use nalgebra::{Vector2, Vector6}; use std::f64::consts::PI; #[test] @@ -31,10 +31,10 @@ mod tests { let &(g, l) = p; let theta = y[0]; let d_theta = y[1]; - Vector2::new( d_theta, -(g/l) * theta.sin() ) + Vector2::new(d_theta, -(g / l) * theta.sin()) } - let y0 = Vector2::new(0.0, PI/2.0); + let y0 = Vector2::new(0.0, PI / 2.0); // Set up the problem (ODE, Integrator, Controller, and Callbacks) let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); @@ -42,7 +42,7 @@ mod tests { let controller = PIController::default(); let value_too_high = Callback { - event: &|t: f64, _y: Vector2, _p: &Params| { 5.0 - t }, + event: &|t: f64, _y: Vector2, _p: &Params| 5.0 - t, effect: &stop, }; @@ -59,7 +59,7 @@ mod tests { // Calculate one period let a = 6.7781363e6_f64; let mu = 3.98600441500000e14; - let period = 2.0 * PI * (a.powi(3)/mu).sqrt(); + let period = 2.0 * PI * (a.powi(3) / mu).sqrt(); // Set up the system type Params = (f64,); @@ -78,7 +78,7 @@ mod tests { ); // Integrate - let ode = ODE::new(&derivative, 0.0, 10.0*period, y0, params); + let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params); let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64); let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); @@ -86,12 +86,40 @@ mod tests { let solution = problem.solve(); - assert_relative_eq!(solution.times[solution.states.len()-1], 10.0 * period, max_relative=1e-12); - assert_relative_eq!(solution.states[solution.states.len()-1][0], y0[0], max_relative=1e-9); - assert_relative_eq!(solution.states[solution.states.len()-1][1], y0[1], max_relative=1e-9); - assert_relative_eq!(solution.states[solution.states.len()-1][2], y0[2], max_relative=1e-9); - assert_relative_eq!(solution.states[solution.states.len()-1][3], y0[3], max_relative=1e-9); - assert_relative_eq!(solution.states[solution.states.len()-1][4], y0[4], max_relative=1e-9); - assert_relative_eq!(solution.states[solution.states.len()-1][5], y0[5], max_relative=1e-9); + assert_relative_eq!( + solution.times[solution.states.len() - 1], + 10.0 * period, + max_relative = 1e-12 + ); + assert_relative_eq!( + solution.states[solution.states.len() - 1][0], + y0[0], + max_relative = 1e-9 + ); + assert_relative_eq!( + solution.states[solution.states.len() - 1][1], + y0[1], + max_relative = 1e-9 + ); + assert_relative_eq!( + solution.states[solution.states.len() - 1][2], + y0[2], + max_relative = 1e-9 + ); + assert_relative_eq!( + solution.states[solution.states.len() - 1][3], + y0[3], + max_relative = 1e-9 + ); + assert_relative_eq!( + solution.states[solution.states.len() - 1][4], + y0[4], + max_relative = 1e-9 + ); + assert_relative_eq!( + solution.states[solution.states.len() - 1][5], + y0[5], + max_relative = 1e-9 + ); } } diff --git a/src/ode.rs b/src/ode.rs index a5b366f..5ae10b4 100644 --- a/src/ode.rs +++ b/src/ode.rs @@ -4,8 +4,8 @@ use nalgebra::SVector; /// determined upon creation of the object #[derive(Clone, Copy)] pub struct ODE<'a, const D: usize, P> { - pub f: &'a dyn Fn(f64, SVector, &P) -> SVector, - pub y: SVector, + pub f: &'a dyn Fn(f64, SVector, &P) -> SVector, + pub y: SVector, pub t: f64, pub params: P, pub t0: f64, @@ -14,28 +14,27 @@ pub struct ODE<'a, const D: usize, P> { pub finished: bool, } -impl<'a, const D: usize, P> ODE<'a,D, P> { +impl<'a, const D: usize, P> ODE<'a, D, P> { pub fn new( - f: &'a (dyn Fn(f64, SVector, &P) -> SVector), + f: &'a (dyn Fn(f64, SVector, &P) -> SVector), t0: f64, t_end: f64, - y0: SVector, + y0: SVector, params: P, ) -> Self { Self { - f: f, + f, y: y0, t: t0, - params: params, - t0: t0, - t_end: t_end, + params, + t0, + t_end, h: 0.001, finished: false, } } } - #[cfg(test)] mod tests { use super::*; @@ -44,7 +43,9 @@ mod tests { #[test] fn test_ode_creation() { type Params = (); - fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { -y } + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { + -y + } let y0 = Vector3::new(1.0, 0.0, 0.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); @@ -62,7 +63,11 @@ mod tests { let params = (34.0, true); fn derivative(t: f64, y: Vector3, p: &Params) -> Vector3 { - if p.1 { -y } else { y * t } + if p.1 { + -y + } else { + y * t + } } let y0 = Vector3::new(1.0, 0.0, 0.0); diff --git a/src/problem.rs b/src/problem.rs index 08e12e5..5e2f628 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -1,10 +1,10 @@ use nalgebra::SVector; use roots::find_root_regula_falsi; -use super::ode::ODE; +use super::callback::Callback; use super::controller::{Controller, PIController}; use super::integrator::Integrator; -use super::callback::Callback; +use super::ode::ODE; #[derive(Clone)] pub struct Problem<'a, const D: usize, S, P> @@ -17,11 +17,11 @@ where callbacks: Vec>, } -impl<'a, const D: usize, S, P> Problem<'a,D,S,P> +impl<'a, const D: usize, S, P> Problem<'a, D, S, P> where S: Integrator + Copy, { - pub fn new(ode: ODE<'a,D,P>, integrator: S, controller: PIController) -> Self { + pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self { Problem { ode: ode, integrator: integrator, @@ -30,24 +30,31 @@ where } } pub fn solve(&mut self) -> Solution { - let mut times: Vec:: = vec![self.ode.t]; - let mut states: Vec::> = vec![self.ode.y]; - let mut dense_coefficients: Vec::>> = Vec::new(); + let mut times: Vec = vec![self.ode.t]; + let mut states: Vec> = vec![self.ode.y]; + let mut dense_coefficients: Vec>> = Vec::new(); let mut step: f64 = self.controller.old_h; let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0); while self.ode.t < self.ode.t_end { - let mut dense_option: Option>> = None; + let mut dense_option: Option>> = None; if S::ADAPTIVE { let mut err = err_option.unwrap(); let mut accepted: bool = false; while !accepted { // Try a step and if that isn't acceptable, then change the step until it is - (accepted, step) = >::determine_step(&mut self.controller, step, err); + (accepted, step) = >::determine_step( + &mut self.controller, + step, + err, + ); (new_y, err_option, dense_option) = self.integrator.step(&self.ode, step); err = err_option.unwrap(); } self.controller.old_h = step; - self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step); + self.controller.h_max = self + .controller + .h_max + .min(self.ode.t_end - self.ode.t - step); } else { // If fixed time step just step forward one step (new_y, _, dense_option) = self.integrator.step(&self.ode, step); @@ -55,7 +62,10 @@ where if self.callbacks.len() > 0 { // Check for events occurring for callback in &self.callbacks { - if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params) < 0.0 { + if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) + * (callback.event)(self.ode.t + step, new_y, &self.ode.params) + < 0.0 + { // If the event crossed zero, then find the root let f = |test_t| { let test_y = self.integrator.step(&self.ode, test_t).0; @@ -77,8 +87,8 @@ where } Solution { integrator: self.integrator, - times: times, - states: states, + times, + states, dense: dense_coefficients, } } @@ -94,14 +104,20 @@ where } } -pub struct Solution where S: Integrator { +pub struct Solution +where + S: Integrator, +{ pub integrator: S, pub times: Vec, - pub states: Vec>, - pub dense: Vec::>>, + pub states: Vec>, + pub dense: Vec>>, } -impl Solution where S: Integrator { +impl Solution +where + S: Integrator, +{ pub fn interpolate(&self, t: f64) -> SVector { // First check that the t is within bounds let last = self.times.last().unwrap(); @@ -109,8 +125,12 @@ impl Solution where S: Integrator { // TODO: Improve these errors let mut times = self.times.clone(); - if *first > *last { times.reverse(); } - if t < *first || t > *last { panic!(); } + if *first > *last { + times.reverse(); + } + if t < *first || t > *last { + panic!(); + } // Then find the two t values closest to the desired t let mut end_index: usize = 0; @@ -124,23 +144,26 @@ impl Solution where S: Integrator { // Then send that to the integrator let t_start = times[end_index - 1]; let t_end = times[end_index]; - self.integrator.interpolate(t_start, t_end, &self.dense[end_index - 1], t) + self.integrator + .interpolate(t_start, t_end, &self.dense[end_index - 1], t) } } #[cfg(test)] mod tests { use super::*; - use nalgebra::Vector3; - use approx::assert_relative_eq; - use crate::integrator::dormand_prince::DormandPrince45; - use crate::controller::PIController; use crate::callback::stop; + use crate::controller::PIController; + use crate::integrator::dormand_prince::DormandPrince45; + use approx::assert_relative_eq; + use nalgebra::Vector3; #[test] fn test_problem() { type Params = (); - fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { + y + } let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); @@ -150,15 +173,21 @@ mod tests { let mut problem = Problem::new(ode, dp45, controller); let solution = problem.solve(); - solution.times.iter().zip(solution.states.iter()).for_each(|(time, state)| { - assert_relative_eq!(state[0], time.exp(), max_relative=1e-2); - }) + solution + .times + .iter() + .zip(solution.states.iter()) + .for_each(|(time, state)| { + assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2); + }) } #[test] fn test_with_callback() { type Params = (); - fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { + y + } let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); @@ -166,7 +195,7 @@ mod tests { let controller = PIController::default(); let value_too_high = Callback { - event: &|_: f64, y: SVector, _: &Params| { 10.0 - y[0] }, + event: &|_: f64, y: SVector, _: &Params| 10.0 - y[0], effect: &stop, }; @@ -179,7 +208,9 @@ mod tests { #[test] fn test_with_interpolation() { type Params = (); - fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { y } + fn derivative(_t: f64, y: Vector3, _p: &Params) -> Vector3 { + y + } let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); @@ -189,6 +220,10 @@ mod tests { let mut problem = Problem::new(ode, dp45, controller); let solution = problem.solve(); - assert_relative_eq!(solution.interpolate(8.8)[0], 8.8_f64.exp(), max_relative=1e-6); + assert_relative_eq!( + solution.interpolate(8.8)[0], + 8.8_f64.exp(), + max_relative = 1e-6 + ); } }