Added some benchmarking and small performance improvements
This commit is contained in:
@@ -13,3 +13,12 @@ roots = "0.0.8"
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
approx = "0.5"
|
approx = "0.5"
|
||||||
|
criterion = "0.7.0"
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "simple_1d"
|
||||||
|
harness = false
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "orbit"
|
||||||
|
harness = false
|
||||||
|
|||||||
40
benches/orbit.rs
Normal file
40
benches/orbit.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
|
||||||
|
use differential_equations::prelude::*;
|
||||||
|
use nalgebra::Vector6;
|
||||||
|
|
||||||
|
fn bench_orbit(c: &mut Criterion) {
|
||||||
|
let mu = 3.98600441500000e14;
|
||||||
|
|
||||||
|
// Set up the system
|
||||||
|
type Params = (f64,);
|
||||||
|
let params = (mu,);
|
||||||
|
fn derivative(_t: f64, state: Vector6<f64>, p: &Params) -> Vector6<f64> {
|
||||||
|
let acc = -(p.0 * state.fixed_rows::<3>(0)) / (state.fixed_rows::<3>(0).norm().powi(3));
|
||||||
|
Vector6::new(state[3], state[4], state[5], acc[0], acc[1], acc[2])
|
||||||
|
}
|
||||||
|
let y0 = Vector6::new(
|
||||||
|
4.263868426884883e6,
|
||||||
|
5.146189057155391e6,
|
||||||
|
1.1310208421331816e6,
|
||||||
|
-5923.454461876975,
|
||||||
|
4496.802639690076,
|
||||||
|
1870.3893008991558,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Integrate
|
||||||
|
let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64);
|
||||||
|
let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01);
|
||||||
|
|
||||||
|
c.bench_function("bench_orbit", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
std::hint::black_box({
|
||||||
|
Problem::new(ode, dp45, controller).solve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, bench_orbit);
|
||||||
|
criterion_main!(benches);
|
||||||
56
benches/simple_1d.rs
Normal file
56
benches/simple_1d.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
|
||||||
|
use differential_equations::prelude::*;
|
||||||
|
use nalgebra::Vector1;
|
||||||
|
|
||||||
|
fn bench_simple_1d(c: &mut Criterion) {
|
||||||
|
type Params = (f64,);
|
||||||
|
let params = (0.1,);
|
||||||
|
|
||||||
|
fn derivative(_t: f64, y: Vector1<f64>, p: &Params) -> Vector1<f64> {
|
||||||
|
Vector1::new(-p.0 * y[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
let y0 = Vector1::new(1.0);
|
||||||
|
|
||||||
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
|
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64);
|
||||||
|
let controller = PIController::default();
|
||||||
|
|
||||||
|
c.bench_function("bench_simple_1d", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
std::hint::black_box({
|
||||||
|
Problem::new(ode, dp45, controller).solve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_interpolation_1d(c: &mut Criterion) {
|
||||||
|
type Params = (f64,);
|
||||||
|
let params = (0.1,);
|
||||||
|
|
||||||
|
fn derivative(_t: f64, y: Vector1<f64>, p: &Params) -> Vector1<f64> {
|
||||||
|
Vector1::new(-p.0 * y[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
let y0 = Vector1::new(1.0);
|
||||||
|
|
||||||
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
|
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new(1e-1_f64, 1e-6_f64);
|
||||||
|
let controller = PIController::default();
|
||||||
|
|
||||||
|
c.bench_function("bench_interpolation_1d", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
std::hint::black_box({
|
||||||
|
let solution = Problem::new(ode, dp45, controller).solve();
|
||||||
|
let _ = (0..100).map(|t| solution.interpolate(t as f64 * 0.1)[0]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, bench_simple_1d, bench_interpolation_1d,);
|
||||||
|
criterion_main!(benches);
|
||||||
28
src/lib.rs
28
src/lib.rs
@@ -18,7 +18,7 @@ pub mod prelude {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use crate::prelude::*;
|
use crate::prelude::*;
|
||||||
use approx::assert_relative_eq;
|
use approx::assert_relative_eq;
|
||||||
use nalgebra::{Vector2, Vector6};
|
use nalgebra::{Vector1, Vector2, Vector6};
|
||||||
use std::f64::consts::PI;
|
use std::f64::consts::PI;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -54,6 +54,32 @@ mod tests {
|
|||||||
let _interpolated_answer = solution.interpolate(4.4);
|
let _interpolated_answer = solution.interpolate(4.4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_correctness() {
|
||||||
|
// Define the system (parameters, derivative, and initial state)
|
||||||
|
type Params = ();
|
||||||
|
let params = ();
|
||||||
|
|
||||||
|
fn derivative(_t: f64, y: Vector1<f64>, _p: &Params) -> Vector1<f64> {
|
||||||
|
Vector1::new(5.0 * y[0] - 3.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
let y0 = Vector1::new(1.0);
|
||||||
|
|
||||||
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
|
let ode = ODE::new(&derivative, 2.0, 3.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new(1e-8_f64, 1e-8_f64);
|
||||||
|
let controller = PIController::default();
|
||||||
|
|
||||||
|
// Solve the problem
|
||||||
|
let mut problem = Problem::new(ode, dp45, controller);
|
||||||
|
let solution = problem.solve();
|
||||||
|
for (time, state) in solution.times.iter().zip(solution.states.iter()) {
|
||||||
|
let exact = 0.4 * (5.0 * (time - 2.0)).exp() + 0.6;
|
||||||
|
assert_relative_eq!(state[0], exact, max_relative = 1e-7);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_orbit() {
|
fn test_orbit() {
|
||||||
// Calculate one period
|
// Calculate one period
|
||||||
|
|||||||
@@ -134,19 +134,16 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Then find the two t values closest to the desired t
|
// Then find the two t values closest to the desired t
|
||||||
let mut end_index: usize = 0;
|
match times.binary_search_by(|x| x.total_cmp(&t)) {
|
||||||
for (i, time) in self.times.iter().enumerate() {
|
Ok(index) => self.states[index],
|
||||||
if time > &t {
|
Err(end_index) => {
|
||||||
end_index = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then send that to the integrator
|
// Then send that to the integrator
|
||||||
let t_start = times[end_index - 1];
|
let t_start = times[end_index - 1];
|
||||||
let t_end = times[end_index];
|
let t_end = times[end_index];
|
||||||
self.integrator
|
self.integrator
|
||||||
.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
|
.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user