Compare commits
	
		
			10 Commits
		
	
	
		
			3a15323a9c
			...
			b42f3a3e77
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | b42f3a3e77 | ||
|   | 69d2fe4336 | ||
|   | 44bb3e5ac1 | ||
|   | 0dfed1cd06 | ||
|   | 2659d78582 | ||
|   | 9075dac669 | ||
|   | e27ef0a07c | ||
|   | 0cfd4f1f5d | ||
|   | 76089fa012 | ||
|   | 5d0a7d6e84 | 
							
								
								
									
										22
									
								
								Cargo.toml
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								Cargo.toml
									
									
									
									
									
								
							| @@ -1,15 +1,29 @@ | |||||||
| [package] | [package] | ||||||
| name = "differential_equations" | name = "ordinary-diffeq" | ||||||
| version = "0.2.1" | version = "0.2.3" | ||||||
| edition = "2021" | edition = "2021" | ||||||
|  | authors = ["Connor Johnstone"] | ||||||
|  | description = "A library for solving differential equations based on the DifferentialEquations.jl julia library." | ||||||
|  | readme = "readme.md" | ||||||
|  | repository = "https://gitlab.rcjohnstone.com/connor/differential-equations" | ||||||
|  | license = "MIT" | ||||||
|  |  | ||||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||||
|  |  | ||||||
| [dependencies] | [dependencies] | ||||||
| serde = { version = "1.0", features = ["derive"] } | serde = { version = "1.0", features = ["derive"] } | ||||||
| nalgebra = { version = "0.32", features = ["serde-serialize"] } | nalgebra = { version = "0.34", features = ["serde-serialize"] } | ||||||
| num-traits = "0.2.15" | num-traits = "0.2.19" | ||||||
| roots = "0.0.8" | roots = "0.0.8" | ||||||
|  |  | ||||||
| [dev-dependencies] | [dev-dependencies] | ||||||
| approx = "0.5" | approx = "0.5" | ||||||
|  | criterion = "0.7.0" | ||||||
|  |  | ||||||
|  | [[bench]] | ||||||
|  | name = "simple_1d" | ||||||
|  | harness = false | ||||||
|  |  | ||||||
|  | [[bench]] | ||||||
|  | name = "orbit" | ||||||
|  | harness = false | ||||||
|   | |||||||
							
								
								
									
										40
									
								
								benches/orbit.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								benches/orbit.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | use criterion::{criterion_group, criterion_main, Criterion}; | ||||||
|  |  | ||||||
|  | use ordinary_diffeq::prelude::*; | ||||||
|  | use nalgebra::Vector6; | ||||||
|  |  | ||||||
|  | fn bench_orbit(c: &mut Criterion) { | ||||||
|  |     let mu = 3.98600441500000e14; | ||||||
|  |  | ||||||
|  |     // Set up the system | ||||||
|  |     type Params = (f64,); | ||||||
|  |     let params = (mu,); | ||||||
|  |     fn derivative(_t: f64, state: Vector6<f64>, p: &Params) -> Vector6<f64> { | ||||||
|  |         let acc = -(p.0 * state.fixed_rows::<3>(0)) / (state.fixed_rows::<3>(0).norm().powi(3)); | ||||||
|  |         Vector6::new(state[3], state[4], state[5], acc[0], acc[1], acc[2]) | ||||||
|  |     } | ||||||
|  |     let y0 = Vector6::new( | ||||||
|  |         4.263868426884883e6, | ||||||
|  |         5.146189057155391e6, | ||||||
|  |         1.1310208421331816e6, | ||||||
|  |         -5923.454461876975, | ||||||
|  |         4496.802639690076, | ||||||
|  |         1870.3893008991558, | ||||||
|  |     ); | ||||||
|  |  | ||||||
|  |     // Integrate | ||||||
|  |     let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params); | ||||||
|  |     let dp45 = DormandPrince45::new(); | ||||||
|  |     let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); | ||||||
|  |  | ||||||
|  |     c.bench_function("bench_orbit", |b| { | ||||||
|  |         b.iter(|| { | ||||||
|  |             std::hint::black_box({ | ||||||
|  |                 Problem::new(ode, dp45, controller).solve(); | ||||||
|  |             }); | ||||||
|  |         }); | ||||||
|  |     }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | criterion_group!(benches, bench_orbit); | ||||||
|  | criterion_main!(benches); | ||||||
							
								
								
									
										56
									
								
								benches/simple_1d.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								benches/simple_1d.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | |||||||
|  | use criterion::{criterion_group, criterion_main, Criterion}; | ||||||
|  |  | ||||||
|  | use ordinary_diffeq::prelude::*; | ||||||
|  | use nalgebra::Vector1; | ||||||
|  |  | ||||||
|  | fn bench_simple_1d(c: &mut Criterion) { | ||||||
|  |     type Params = (f64,); | ||||||
|  |     let params = (0.1,); | ||||||
|  |  | ||||||
|  |     fn derivative(_t: f64, y: Vector1<f64>, p: &Params) -> Vector1<f64> { | ||||||
|  |         Vector1::new(-p.0 * y[0]) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let y0 = Vector1::new(1.0); | ||||||
|  |  | ||||||
|  |     // Set up the problem (ODE, Integrator, Controller, and Callbacks) | ||||||
|  |     let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); | ||||||
|  |     let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6); | ||||||
|  |     let controller = PIController::default(); | ||||||
|  |  | ||||||
|  |     c.bench_function("bench_simple_1d", |b| { | ||||||
|  |         b.iter(|| { | ||||||
|  |             std::hint::black_box({ | ||||||
|  |                 Problem::new(ode, dp45, controller).solve(); | ||||||
|  |             }); | ||||||
|  |         }); | ||||||
|  |     }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | fn bench_interpolation_1d(c: &mut Criterion) { | ||||||
|  |     type Params = (f64,); | ||||||
|  |     let params = (0.1,); | ||||||
|  |  | ||||||
|  |     fn derivative(_t: f64, y: Vector1<f64>, p: &Params) -> Vector1<f64> { | ||||||
|  |         Vector1::new(-p.0 * y[0]) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     let y0 = Vector1::new(1.0); | ||||||
|  |  | ||||||
|  |     // Set up the problem (ODE, Integrator, Controller, and Callbacks) | ||||||
|  |     let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); | ||||||
|  |     let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6); | ||||||
|  |     let controller = PIController::default(); | ||||||
|  |  | ||||||
|  |     c.bench_function("bench_interpolation_1d", |b| { | ||||||
|  |         b.iter(|| { | ||||||
|  |             std::hint::black_box({ | ||||||
|  |                 let solution = Problem::new(ode, dp45, controller).solve(); | ||||||
|  |                 let _ = (0..100).map(|t| solution.interpolate(t as f64 * 0.1)[0]); | ||||||
|  |             }); | ||||||
|  |         }); | ||||||
|  |     }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | criterion_group!(benches, bench_simple_1d, bench_interpolation_1d,); | ||||||
|  | criterion_main!(benches); | ||||||
| @@ -53,7 +53,7 @@ let y0 = Vector2::new(0.0, PI/2.0); | |||||||
|  |  | ||||||
| // Set up the problem (ODE, Integrator, Controller, and Callbacks) | // Set up the problem (ODE, Integrator, Controller, and Callbacks) | ||||||
| let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); | let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); | ||||||
| let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); | let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6); | ||||||
| let controller = PIController::default(); | let controller = PIController::default(); | ||||||
|  |  | ||||||
| let value_too_high = Callback { | let value_too_high = Callback { | ||||||
|   | |||||||
| @@ -11,11 +11,11 @@ pub struct Callback<'a, const D: usize, P> { | |||||||
|     pub event: &'a dyn Fn(f64, SVector<f64, D>, &P) -> f64, |     pub event: &'a dyn Fn(f64, SVector<f64, D>, &P) -> f64, | ||||||
|  |  | ||||||
|     /// The function to change the ODE |     /// The function to change the ODE | ||||||
|     pub effect: &'a dyn Fn(&mut ODE<D, P>) -> (), |     pub effect: &'a dyn Fn(&mut ODE<D, P>), | ||||||
| } | } | ||||||
|  |  | ||||||
| /// A convenience function for stopping the integration | /// A convenience function for stopping the integration | ||||||
| pub fn stop<'a, const D: usize, P>(ode: &mut ODE<D, P>) -> () { | pub fn stop<const D: usize, P>(ode: &mut ODE<D, P>) { | ||||||
|     ode.t_end = ode.t; |     ode.t_end = ode.t; | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,5 +1,31 @@ | |||||||
|  | #[derive(Debug, Clone, Copy, PartialEq)] | ||||||
|  | pub enum TryStep { | ||||||
|  |     Accepted(f64, f64), | ||||||
|  |     NotYetAccepted(f64), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl TryStep { | ||||||
|  |     pub fn extract(&self) -> f64 { | ||||||
|  |         match self { | ||||||
|  |             TryStep::Accepted(h, _) => *h, | ||||||
|  |             TryStep::NotYetAccepted(h) => *h, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn is_accepted(&self) -> bool { | ||||||
|  |         matches!(self, TryStep::Accepted(_, _)) | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     pub fn reset(&mut self) -> Result<TryStep, &str> { | ||||||
|  |         match self { | ||||||
|  |             TryStep::Accepted(_, h) => Ok(TryStep::NotYetAccepted(*h)), | ||||||
|  |             TryStep::NotYetAccepted(_) => Err("Cannot reset a NotYetAccepted TryStep"), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| pub trait Controller<const D: usize> { | pub trait Controller<const D: usize> { | ||||||
|     fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64); |     fn determine_step(&mut self, h: f64, err: f64) -> TryStep; | ||||||
| } | } | ||||||
|  |  | ||||||
| #[derive(Debug, Clone, Copy)] | #[derive(Debug, Clone, Copy)] | ||||||
| @@ -11,32 +37,30 @@ pub struct PIController { | |||||||
|     pub factor_old: f64, |     pub factor_old: f64, | ||||||
|     pub h_max: f64, |     pub h_max: f64, | ||||||
|     pub safety_factor: f64, |     pub safety_factor: f64, | ||||||
|     pub old_h: f64, |     pub next_step_guess: TryStep, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl<const D: usize> Controller<D> for PIController { | impl<const D: usize> Controller<D> for PIController { | ||||||
|     /// Determines if the previously run step size and error were valid or not. Either way, it also |     /// Determines if the previously run step size and error were valid or not. Either way, it also | ||||||
|     /// returns what the next step size should be |     /// returns what the next step size should be | ||||||
|     fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) { |     fn determine_step(&mut self, prev_step: f64, err: f64) -> TryStep { | ||||||
|         let factor_11 = err.powf(self.alpha); |         let factor_11 = err.powf(self.alpha); | ||||||
|         let factor = self.factor_c2.max( |         let factor = self.factor_c2.max( | ||||||
|             self.factor_c1 |             self.factor_c1 | ||||||
|                 .min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor), |                 .min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor), | ||||||
|         ); |         ); | ||||||
|         let mut h_new = h / factor; |  | ||||||
|         if err <= 1.0 { |         if err <= 1.0 { | ||||||
|             // Accept the stepsize |             let mut h = prev_step / factor; | ||||||
|  |             // Accept the stepsize and provide what the next step size should be | ||||||
|             self.factor_old = err.max(1.0e-4); |             self.factor_old = err.max(1.0e-4); | ||||||
|             if h_new.abs() > self.h_max { |             if h.abs() > self.h_max { | ||||||
|                 // If the step is too big |                 // If the step goes past the maximum allowed, though, we shrink it | ||||||
|                 h_new = self.h_max.copysign(h_new); |                 h = self.h_max.copysign(h); | ||||||
|             } |             } | ||||||
|             (true, h_new) |             TryStep::Accepted(prev_step, h) | ||||||
|             // (true, h_new) |  | ||||||
|         } else { |         } else { | ||||||
|             // Reject the stepsize and propose a smaller one |             // Reject the stepsize and propose a smaller one for the current step | ||||||
|             h_new = h / (self.factor_c1.min(factor_11 / self.safety_factor)); |             TryStep::NotYetAccepted(prev_step / (self.factor_c1.min(factor_11 / self.safety_factor))) | ||||||
|             (false, h_new) |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -52,17 +76,20 @@ impl PIController { | |||||||
|         initial_h: f64, |         initial_h: f64, | ||||||
|     ) -> Self { |     ) -> Self { | ||||||
|         Self { |         Self { | ||||||
|             alpha: alpha, |             alpha, | ||||||
|             beta: beta, |             beta, | ||||||
|             factor_c1: 1.0 / min_factor, |             factor_c1: 1.0 / min_factor, | ||||||
|             factor_c2: 1.0 / max_factor, |             factor_c2: 1.0 / max_factor, | ||||||
|             factor_old: 1.0e-4, |             factor_old: 1.0e-4, | ||||||
|             h_max: h_max.abs(), |             h_max: h_max.abs(), | ||||||
|             safety_factor: safety_factor, |             safety_factor, | ||||||
|             old_h: initial_h, |             next_step_guess: TryStep::NotYetAccepted(initial_h), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     pub fn default() -> Self { | } | ||||||
|  |  | ||||||
|  | impl Default for PIController { | ||||||
|  |     fn default() -> Self { | ||||||
|         Self::new(0.17, 0.04, 10.0, 0.2, 100000.0, 0.9, 1e-4) |         Self::new(0.17, 0.04, 10.0, 0.2, 100000.0, 0.9, 1e-4) | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @@ -82,6 +109,6 @@ mod tests { | |||||||
|         assert!(controller.factor_old == 1.0e-4); |         assert!(controller.factor_old == 1.0e-4); | ||||||
|         assert!(controller.h_max == 10.0); |         assert!(controller.h_max == 10.0); | ||||||
|         assert!(controller.safety_factor == 0.9); |         assert!(controller.safety_factor == 0.9); | ||||||
|         assert!(controller.old_h == 1e-4); |         assert!(controller.next_step_guess == TryStep::NotYetAccepted(1e-4)); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ pub trait DormandPrinceIntegrator<'a> { | |||||||
|  |  | ||||||
| #[derive(Debug, Clone, Copy)] | #[derive(Debug, Clone, Copy)] | ||||||
| pub struct DormandPrince45<const D: usize> { | pub struct DormandPrince45<const D: usize> { | ||||||
|     a_tol: f64, |     a_tol: SVector<f64,D>, | ||||||
|     r_tol: f64, |     r_tol: f64, | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -21,8 +21,17 @@ impl<const D: usize> DormandPrince45<D> | |||||||
| where | where | ||||||
|     DormandPrince45<D>: Integrator<D>, |     DormandPrince45<D>: Integrator<D>, | ||||||
| { | { | ||||||
|     pub fn new(a_tol: f64, r_tol: f64) -> Self { |     pub fn new() -> Self { | ||||||
|         Self { a_tol, r_tol } |         Self { a_tol: SVector::<f64,D>::from_element(1e-8), r_tol: 1e-8 } | ||||||
|  |     } | ||||||
|  |     pub fn a_tol(&mut self, a_tol: f64) -> Self { | ||||||
|  |         Self { a_tol: SVector::<f64,D>::from_element(a_tol), r_tol: self.r_tol } | ||||||
|  |     } | ||||||
|  |     pub fn a_tol_full(&mut self, a_tol: SVector::<f64,D>) -> Self { | ||||||
|  |         Self { a_tol, r_tol: self.r_tol } | ||||||
|  |     } | ||||||
|  |     pub fn r_tol(&mut self, r_tol: f64) -> Self { | ||||||
|  |         Self { a_tol: self.a_tol, r_tol } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -93,7 +102,7 @@ where | |||||||
|         h: f64, |         h: f64, | ||||||
|     ) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) { |     ) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) { | ||||||
|         let mut k: Vec<SVector<f64, D>> = vec![SVector::<f64, D>::zeros(); Self::STAGES]; |         let mut k: Vec<SVector<f64, D>> = vec![SVector::<f64, D>::zeros(); Self::STAGES]; | ||||||
|         let mut next_y = ode.y.clone(); |         let mut next_y = ode.y; | ||||||
|         let mut err = SVector::<f64, D>::zeros(); |         let mut err = SVector::<f64, D>::zeros(); | ||||||
|         let mut rcont5 = SVector::<f64, D>::zeros(); |         let mut rcont5 = SVector::<f64, D>::zeros(); | ||||||
|         // Do the first of the summations |         // Do the first of the summations | ||||||
| @@ -106,8 +115,8 @@ where | |||||||
|         for i in 1..Self::STAGES { |         for i in 1..Self::STAGES { | ||||||
|             // Compute the ks |             // Compute the ks | ||||||
|             let mut y_term = SVector::<f64, D>::zeros(); |             let mut y_term = SVector::<f64, D>::zeros(); | ||||||
|             for j in 0..i { |             for (j, item) in k.iter().enumerate().take(i) { | ||||||
|                 y_term += k[j] * Self::A[(i * (i - 1)) / 2 + j]; |                 y_term += item * Self::A[(i * (i - 1)) / 2 + j]; | ||||||
|             } |             } | ||||||
|             k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params); |             k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params); | ||||||
|  |  | ||||||
| @@ -119,7 +128,7 @@ where | |||||||
|         let rcont2 = next_y - ode.y; |         let rcont2 = next_y - ode.y; | ||||||
|         let rcont3 = h * k[0] - rcont2; |         let rcont3 = h * k[0] - rcont2; | ||||||
|         let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; |         let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3; | ||||||
|         let tol = SVector::<f64, D>::repeat(self.a_tol) + ode.y * self.r_tol; |         let tol = self.a_tol + ode.y * self.r_tol; | ||||||
|         let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5]; |         let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5]; | ||||||
|         (next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) |         (next_y, Some((err.component_div(&tol)).norm()), Some(rcont)) | ||||||
|     } |     } | ||||||
| @@ -127,7 +136,7 @@ where | |||||||
|         &self, |         &self, | ||||||
|         t_start: f64, |         t_start: f64, | ||||||
|         t_end: f64, |         t_end: f64, | ||||||
|         dense: &Vec<SVector<f64, D>>, |         dense: &[SVector<f64, D>], | ||||||
|         t: f64, |         t: f64, | ||||||
|     ) -> SVector<f64, D> { |     ) -> SVector<f64, D> { | ||||||
|         let s = (t - t_start) / (t_end - t_start); |         let s = (t - t_start) / (t_end - t_start); | ||||||
|   | |||||||
| @@ -22,7 +22,7 @@ pub trait Integrator<const D: usize> { | |||||||
|         &self, |         &self, | ||||||
|         t_start: f64, |         t_start: f64, | ||||||
|         t_end: f64, |         t_end: f64, | ||||||
|         dense: &Vec<SVector<f64, D>>, |         dense: &[SVector<f64, D>], | ||||||
|         t: f64, |         t: f64, | ||||||
|     ) -> SVector<f64, D>; |     ) -> SVector<f64, D>; | ||||||
| } | } | ||||||
| @@ -44,7 +44,7 @@ mod tests { | |||||||
|         let y0 = Vector3::new(1.0, 1.0, 1.0); |         let y0 = Vector3::new(1.0, 1.0, 1.0); | ||||||
|         let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); |         let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ()); | ||||||
|  |  | ||||||
|         let dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64); |         let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-4); | ||||||
|  |  | ||||||
|         // Test that y'(t) = y(t) solves to y(t) = e^t for rkf54 |         // Test that y'(t) = y(t) solves to y(t) = e^t for rkf54 | ||||||
|         // and also that the error seems reasonable |         // and also that the error seems reasonable | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								src/lib.rs
									
									
									
									
									
								
							| @@ -18,7 +18,7 @@ pub mod prelude { | |||||||
| mod tests { | mod tests { | ||||||
|     use crate::prelude::*; |     use crate::prelude::*; | ||||||
|     use approx::assert_relative_eq; |     use approx::assert_relative_eq; | ||||||
|     use nalgebra::{Vector2, Vector6}; |     use nalgebra::{Vector1, Vector2, Vector6}; | ||||||
|     use std::f64::consts::PI; |     use std::f64::consts::PI; | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
| @@ -38,7 +38,7 @@ mod tests { | |||||||
|  |  | ||||||
|         // Set up the problem (ODE, Integrator, Controller, and Callbacks) |         // Set up the problem (ODE, Integrator, Controller, and Callbacks) | ||||||
|         let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); |         let ode = ODE::new(&derivative, 0.0, 6.3, y0, params); | ||||||
|         let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); |         let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6); | ||||||
|         let controller = PIController::default(); |         let controller = PIController::default(); | ||||||
|  |  | ||||||
|         let value_too_high = Callback { |         let value_too_high = Callback { | ||||||
| @@ -54,6 +54,32 @@ mod tests { | |||||||
|         let _interpolated_answer = solution.interpolate(4.4); |         let _interpolated_answer = solution.interpolate(4.4); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     #[test] | ||||||
|  |     fn test_correctness() { | ||||||
|  |         // Define the system (parameters, derivative, and initial state) | ||||||
|  |         type Params = (); | ||||||
|  |         let params = (); | ||||||
|  |  | ||||||
|  |         fn derivative(_t: f64, y: Vector1<f64>, _p: &Params) -> Vector1<f64> { | ||||||
|  |             Vector1::new(5.0 * y[0] - 3.0) | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         let y0 = Vector1::new(1.0); | ||||||
|  |  | ||||||
|  |         // Set up the problem (ODE, Integrator, Controller, and Callbacks) | ||||||
|  |         let ode = ODE::new(&derivative, 2.0, 3.0, y0, params); | ||||||
|  |         let dp45 = DormandPrince45::new(); | ||||||
|  |         let controller = PIController::default(); | ||||||
|  |  | ||||||
|  |         // Solve the problem | ||||||
|  |         let mut problem = Problem::new(ode, dp45, controller); | ||||||
|  |         let solution = problem.solve(); | ||||||
|  |         for (time, state) in solution.times.iter().zip(solution.states.iter()) { | ||||||
|  |             let exact = 0.4 * (5.0 * (time - 2.0)).exp() + 0.6; | ||||||
|  |             assert_relative_eq!(state[0], exact, max_relative = 1e-7); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_orbit() { |     fn test_orbit() { | ||||||
|         // Calculate one period |         // Calculate one period | ||||||
| @@ -79,11 +105,9 @@ mod tests { | |||||||
|  |  | ||||||
|         // Integrate |         // Integrate | ||||||
|         let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params); |         let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params); | ||||||
|         let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64); |         let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-12); | ||||||
|         let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); |         let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); | ||||||
|  |  | ||||||
|         let mut problem = Problem::new(ode, dp45, controller); |         let mut problem = Problem::new(ode, dp45, controller); | ||||||
|  |  | ||||||
|         let solution = problem.solve(); |         let solution = problem.solve(); | ||||||
|  |  | ||||||
|         assert_relative_eq!( |         assert_relative_eq!( | ||||||
|   | |||||||
| @@ -1,10 +1,12 @@ | |||||||
| use nalgebra::SVector; | use nalgebra::SVector; | ||||||
|  |  | ||||||
|  | type ProblemFunction<'a, const D: usize, P> = &'a dyn Fn(f64, SVector<f64, D>, &P) -> SVector<f64, D>; | ||||||
|  |  | ||||||
| /// The basic ODE object that will be passed around. The type (T) and the size (D) will be | /// The basic ODE object that will be passed around. The type (T) and the size (D) will be | ||||||
| /// determined upon creation of the object | /// determined upon creation of the object | ||||||
| #[derive(Clone, Copy)] | #[derive(Clone, Copy)] | ||||||
| pub struct ODE<'a, const D: usize, P> { | pub struct ODE<'a, const D: usize, P> { | ||||||
|     pub f: &'a dyn Fn(f64, SVector<f64, D>, &P) -> SVector<f64, D>, |     pub f: ProblemFunction<'a, D, P>, | ||||||
|     pub y: SVector<f64, D>, |     pub y: SVector<f64, D>, | ||||||
|     pub t: f64, |     pub t: f64, | ||||||
|     pub params: P, |     pub params: P, | ||||||
| @@ -16,7 +18,7 @@ pub struct ODE<'a, const D: usize, P> { | |||||||
|  |  | ||||||
| impl<'a, const D: usize, P> ODE<'a, D, P> { | impl<'a, const D: usize, P> ODE<'a, D, P> { | ||||||
|     pub fn new( |     pub fn new( | ||||||
|         f: &'a (dyn Fn(f64, SVector<f64, D>, &P) -> SVector<f64, D>), |         f: ProblemFunction<'a, D, P>, | ||||||
|         t0: f64, |         t0: f64, | ||||||
|         t_end: f64, |         t_end: f64, | ||||||
|         y0: SVector<f64, D>, |         y0: SVector<f64, D>, | ||||||
|   | |||||||
							
								
								
									
										105
									
								
								src/problem.rs
									
									
									
									
									
								
							
							
						
						
									
										105
									
								
								src/problem.rs
									
									
									
									
									
								
							| @@ -1,8 +1,8 @@ | |||||||
| use nalgebra::SVector; | use nalgebra::SVector; | ||||||
| use roots::{find_root_brent, DebugConvergency}; | use roots::{find_root_brent, SimpleConvergency}; | ||||||
|  |  | ||||||
| use super::callback::Callback; | use super::callback::Callback; | ||||||
| use super::controller::{Controller, PIController}; | use super::controller::{Controller, PIController, TryStep}; | ||||||
| use super::integrator::Integrator; | use super::integrator::Integrator; | ||||||
| use super::ode::ODE; | use super::ode::ODE; | ||||||
|  |  | ||||||
| @@ -23,48 +23,64 @@ where | |||||||
| { | { | ||||||
|     pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self { |     pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self { | ||||||
|         Problem { |         Problem { | ||||||
|             ode: ode, |             ode, | ||||||
|             integrator: integrator, |             integrator, | ||||||
|             controller: controller, |             controller, | ||||||
|             callbacks: Vec::new(), |             callbacks: Vec::new(), | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     pub fn solve(&mut self) -> Solution<S, D> { |     pub fn solve(&mut self) -> Solution<S, D> { | ||||||
|         let mut convergency = DebugConvergency::new(1e-12, 50); |         let mut convergency = SimpleConvergency { | ||||||
|  |             eps: 1e-12, | ||||||
|  |             max_iter: 1000, | ||||||
|  |         }; | ||||||
|         let mut times: Vec<f64> = vec![self.ode.t]; |         let mut times: Vec<f64> = vec![self.ode.t]; | ||||||
|         let mut states: Vec<SVector<f64, D>> = vec![self.ode.y]; |         let mut states: Vec<SVector<f64, D>> = vec![self.ode.y]; | ||||||
|         let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new(); |         let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new(); | ||||||
|         let mut step: f64 = self.controller.old_h; |  | ||||||
|         let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0); |  | ||||||
|         while self.ode.t < self.ode.t_end { |         while self.ode.t < self.ode.t_end { | ||||||
|             let mut dense_option: Option<Vec<SVector<f64, D>>> = None; |             if self.ode.t + self.controller.next_step_guess.extract() > self.ode.t_end { | ||||||
|             if S::ADAPTIVE { |                 // If the next step would go past the end, then just set it to the end | ||||||
|  |                 self.controller.next_step_guess = TryStep::NotYetAccepted( | ||||||
|  |                     self.ode.t_end - self.ode.t, | ||||||
|  |                 ); | ||||||
|  |             } | ||||||
|  |             let (mut new_y, mut curr_step, mut dense_option) = if S::ADAPTIVE { | ||||||
|  |                 // First, we try stepping with the "next step guess" to get the error | ||||||
|  |                 let (mut trial_y, mut err_option, mut dense_option) = | ||||||
|  |                     self.integrator.step(&self.ode, self.controller.next_step_guess.extract()); | ||||||
|                 let mut err = err_option.unwrap(); |                 let mut err = err_option.unwrap(); | ||||||
|                 let mut accepted: bool = false; |                 // Then we determine whether we need to reduce the step size or not | ||||||
|                 while !accepted { |                 // If successful, we get the next step guess | ||||||
|                     // Try a step and if that isn't acceptable, then change the step until it is |                 let initial_guess = self.controller.next_step_guess.extract(); | ||||||
|                     (accepted, step) = <PIController as Controller<D>>::determine_step( |                 let mut next_step_guess = <PIController as Controller<D>>::determine_step( | ||||||
|  |                     &mut self.controller, | ||||||
|  |                     initial_guess, | ||||||
|  |                     err, | ||||||
|  |                 ); | ||||||
|  |                 while !next_step_guess.is_accepted() { | ||||||
|  |                     // If that step isn't acceptable, then change the step until it is | ||||||
|  |                     (trial_y, err_option, dense_option) = | ||||||
|  |                         self.integrator.step(&self.ode, next_step_guess.extract()); | ||||||
|  |                     next_step_guess = <PIController as Controller<D>>::determine_step( | ||||||
|                         &mut self.controller, |                         &mut self.controller, | ||||||
|                         step, |                         next_step_guess.extract(), | ||||||
|                         err, |                         err, | ||||||
|                     ); |                     ); | ||||||
|                     (new_y, err_option, dense_option) = self.integrator.step(&self.ode, step); |  | ||||||
|                     err = err_option.unwrap(); |                     err = err_option.unwrap(); | ||||||
|                 } |                 } | ||||||
|                 self.controller.old_h = step; |                 // So at this point we can safely assume we have an accepted step | ||||||
|                 self.controller.h_max = self |                 self.controller.next_step_guess = next_step_guess.reset().unwrap(); | ||||||
|                     .controller |                 (trial_y, next_step_guess.extract(), dense_option) | ||||||
|                     .h_max |  | ||||||
|                     .min(self.ode.t_end - self.ode.t - step); |  | ||||||
|             } else { |             } else { | ||||||
|                 // If fixed time step just step forward one step |                 // If fixed time step just step forward one step | ||||||
|                 (new_y, _, dense_option) = self.integrator.step(&self.ode, step); |                 let (trial_y, _, dense_option) = self.integrator.step(&self.ode, self.controller.next_step_guess.extract()); | ||||||
|             } |                 (trial_y, self.controller.next_step_guess.extract(), dense_option) | ||||||
|             if self.callbacks.len() > 0 { |             }; | ||||||
|  |             if !self.callbacks.is_empty() { | ||||||
|                 // Check for events occurring |                 // Check for events occurring | ||||||
|                 for callback in &self.callbacks { |                 for callback in &self.callbacks { | ||||||
|                     if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) |                     if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) | ||||||
|                         * (callback.event)(self.ode.t + step, new_y, &self.ode.params) |                         * (callback.event)(self.ode.t + curr_step, new_y, &self.ode.params) | ||||||
|                         < 0.0 |                         < 0.0 | ||||||
|                     { |                     { | ||||||
|                         // If the event crossed zero, then find the root |                         // If the event crossed zero, then find the root | ||||||
| @@ -72,15 +88,15 @@ where | |||||||
|                             let test_y = self.integrator.step(&self.ode, test_t).0; |                             let test_y = self.integrator.step(&self.ode, test_t).0; | ||||||
|                             (callback.event)(self.ode.t + test_t, test_y, &self.ode.params) |                             (callback.event)(self.ode.t + test_t, test_y, &self.ode.params) | ||||||
|                         }; |                         }; | ||||||
|                         let root = find_root_brent(0.0, step, &f, &mut convergency).unwrap(); |                         let root = find_root_brent(0.0, curr_step, &f, &mut convergency).unwrap(); | ||||||
|                         step = root; |                         curr_step = root; | ||||||
|                         (new_y, _, dense_option) = self.integrator.step(&self.ode, step); |                         (new_y, _, dense_option) = self.integrator.step(&self.ode, curr_step); | ||||||
|                         (callback.effect)(&mut self.ode); |                         (callback.effect)(&mut self.ode); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             self.ode.y = new_y; |             self.ode.y = new_y; | ||||||
|             self.ode.t += step; |             self.ode.t += curr_step; | ||||||
|             times.push(self.ode.t); |             times.push(self.ode.t); | ||||||
|             states.push(self.ode.y); |             states.push(self.ode.y); | ||||||
|             // TODO: Implement third order interpolation for non-dense algorithms |             // TODO: Implement third order interpolation for non-dense algorithms | ||||||
| @@ -134,19 +150,16 @@ where | |||||||
|         } |         } | ||||||
|  |  | ||||||
|         // Then find the two t values closest to the desired t |         // Then find the two t values closest to the desired t | ||||||
|         let mut end_index: usize = 0; |         match times.binary_search_by(|x| x.total_cmp(&t)) { | ||||||
|         for (i, time) in self.times.iter().enumerate() { |             Ok(index) => self.states[index], | ||||||
|             if time > &t { |             Err(end_index) => { | ||||||
|                 end_index = i; |                 // Then send that to the integrator | ||||||
|                 break; |                 let t_start = times[end_index - 1]; | ||||||
|  |                 let t_end = times[end_index]; | ||||||
|  |                 self.integrator | ||||||
|  |                     .interpolate(t_start, t_end, &self.dense[end_index - 1], t) | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         // Then send that to the integrator |  | ||||||
|         let t_start = times[end_index - 1]; |  | ||||||
|         let t_end = times[end_index]; |  | ||||||
|         self.integrator |  | ||||||
|             .interpolate(t_start, t_end, &self.dense[end_index - 1], t) |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -168,7 +181,7 @@ mod tests { | |||||||
|         let y0 = Vector3::new(1.0, 1.0, 1.0); |         let y0 = Vector3::new(1.0, 1.0, 1.0); | ||||||
|  |  | ||||||
|         let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); |         let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); | ||||||
|         let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); |         let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5); | ||||||
|         let controller = PIController::default(); |         let controller = PIController::default(); | ||||||
|  |  | ||||||
|         let mut problem = Problem::new(ode, dp45, controller); |         let mut problem = Problem::new(ode, dp45, controller); | ||||||
| @@ -192,7 +205,7 @@ mod tests { | |||||||
|         let y0 = Vector3::new(1.0, 1.0, 1.0); |         let y0 = Vector3::new(1.0, 1.0, 1.0); | ||||||
|  |  | ||||||
|         let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); |         let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); | ||||||
|         let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); |         let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5); | ||||||
|         let controller = PIController::default(); |         let controller = PIController::default(); | ||||||
|  |  | ||||||
|         let value_too_high = Callback { |         let value_too_high = Callback { | ||||||
| @@ -203,7 +216,11 @@ mod tests { | |||||||
|         let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); |         let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); | ||||||
|         let solution = problem.solve(); |         let solution = problem.solve(); | ||||||
|  |  | ||||||
|         assert!(solution.states.last().unwrap()[0] == 10.0); |         assert_relative_eq!( | ||||||
|  |             solution.states.last().unwrap()[0], | ||||||
|  |             10.0, | ||||||
|  |             max_relative = 1e-11 | ||||||
|  |         ); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
| @@ -215,7 +232,7 @@ mod tests { | |||||||
|         let y0 = Vector3::new(1.0, 1.0, 1.0); |         let y0 = Vector3::new(1.0, 1.0, 1.0); | ||||||
|  |  | ||||||
|         let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); |         let ode = ODE::new(&derivative, 0.0, 10.0, y0, ()); | ||||||
|         let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); |         let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6); | ||||||
|         let controller = PIController::default(); |         let controller = PIController::default(); | ||||||
|  |  | ||||||
|         let mut problem = Problem::new(ode, dp45, controller); |         let mut problem = Problem::new(ode, dp45, controller); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user