diff --git a/Cargo.toml b/Cargo.toml index 6b715e3..d263ec9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,12 @@ roots = "0.0.8" [dev-dependencies] approx = "0.5" +criterion = "0.7.0" + +[[bench]] +name = "simple_1d" +harness = false + +[[bench]] +name = "orbit" +harness = false diff --git a/benches/orbit.rs b/benches/orbit.rs new file mode 100644 index 0000000..74e5201 --- /dev/null +++ b/benches/orbit.rs @@ -0,0 +1,40 @@ +use criterion::{criterion_group, criterion_main, Criterion}; + +use differential_equations::prelude::*; +use nalgebra::Vector6; + +fn bench_orbit(c: &mut Criterion) { + let mu = 3.98600441500000e14; + + // Set up the system + type Params = (f64,); + let params = (mu,); + fn derivative(_t: f64, state: Vector6, p: &Params) -> Vector6 { + let acc = -(p.0 * state.fixed_rows::<3>(0)) / (state.fixed_rows::<3>(0).norm().powi(3)); + Vector6::new(state[3], state[4], state[5], acc[0], acc[1], acc[2]) + } + let y0 = Vector6::new( + 4.263868426884883e6, + 5.146189057155391e6, + 1.1310208421331816e6, + -5923.454461876975, + 4496.802639690076, + 1870.3893008991558, + ); + + // Integrate + let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params); + let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64); + let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01); + + c.bench_function("bench_orbit", |b| { + b.iter(|| { + std::hint::black_box({ + Problem::new(ode, dp45, controller).solve(); + }); + }); + }); +} + +criterion_group!(benches, bench_orbit); +criterion_main!(benches); diff --git a/benches/simple_1d.rs b/benches/simple_1d.rs new file mode 100644 index 0000000..b247d3d --- /dev/null +++ b/benches/simple_1d.rs @@ -0,0 +1,56 @@ +use criterion::{criterion_group, criterion_main, Criterion}; + +use differential_equations::prelude::*; +use nalgebra::Vector1; + +fn bench_simple_1d(c: &mut Criterion) { + type Params = (f64,); + let params = (0.1,); + + fn derivative(_t: f64, y: Vector1, p: &Params) -> Vector1 { + Vector1::new(-p.0 * y[0]) + } + + let y0 = Vector1::new(1.0); + + // Set up the problem (ODE, Integrator, Controller, and Callbacks) + let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); + let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64); + let controller = PIController::default(); + + c.bench_function("bench_simple_1d", |b| { + b.iter(|| { + std::hint::black_box({ + Problem::new(ode, dp45, controller).solve(); + }); + }); + }); +} + +fn bench_interpolation_1d(c: &mut Criterion) { + type Params = (f64,); + let params = (0.1,); + + fn derivative(_t: f64, y: Vector1, p: &Params) -> Vector1 { + Vector1::new(-p.0 * y[0]) + } + + let y0 = Vector1::new(1.0); + + // Set up the problem (ODE, Integrator, Controller, and Callbacks) + let ode = ODE::new(&derivative, 0.0, 10.0, y0, params); + let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64); + let controller = PIController::default(); + + c.bench_function("bench_interpolation_1d", |b| { + b.iter(|| { + std::hint::black_box({ + let solution = Problem::new(ode, dp45, controller).solve(); + let _ = (0..100).map(|t| solution.interpolate(t as f64 * 0.1)[0]); + }); + }); + }); +} + +criterion_group!(benches, bench_simple_1d, bench_interpolation_1d,); +criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs index da32822..608ef17 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ pub mod prelude { mod tests { use crate::prelude::*; use approx::assert_relative_eq; - use nalgebra::{Vector2, Vector6}; + use nalgebra::{Vector1, Vector2, Vector6}; use std::f64::consts::PI; #[test] @@ -54,6 +54,32 @@ mod tests { let _interpolated_answer = solution.interpolate(4.4); } + #[test] + fn test_correctness() { + // Define the system (parameters, derivative, and initial state) + type Params = (); + let params = (); + + fn derivative(_t: f64, y: Vector1, _p: &Params) -> Vector1 { + Vector1::new(5.0 * y[0] - 3.0) + } + + let y0 = Vector1::new(1.0); + + // Set up the problem (ODE, Integrator, Controller, and Callbacks) + let ode = ODE::new(&derivative, 2.0, 3.0, y0, params); + let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64); + let controller = PIController::default(); + + // Solve the problem + let mut problem = Problem::new(ode, dp45, controller); + let solution = problem.solve(); + for (time, state) in solution.times.iter().zip(solution.states.iter()) { + let exact = 0.4 * (5.0 * (time - 2.0)).exp() + 0.6; + assert_relative_eq!(state[0], exact, max_relative = 1e-7); + } + } + #[test] fn test_orbit() { // Calculate one period diff --git a/src/problem.rs b/src/problem.rs index 8e72f3f..f96c37d 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -134,19 +134,16 @@ where } // Then find the two t values closest to the desired t - let mut end_index: usize = 0; - for (i, time) in self.times.iter().enumerate() { - if time > &t { - end_index = i; - break; - } + match times.binary_search_by(|x| x.total_cmp(&t)) { + Ok(index) => self.states[index], + Err(end_index) => { + // Then send that to the integrator + let t_start = times[end_index - 1]; + let t_end = times[end_index]; + self.integrator + .interpolate(t_start, t_end, &self.dense[end_index - 1], t) + }, } - - // Then send that to the integrator - let t_start = times[end_index - 1]; - let t_end = times[end_index]; - self.integrator - .interpolate(t_start, t_end, &self.dense[end_index - 1], t) } }