Compare commits
10 Commits
3a15323a9c
...
b42f3a3e77
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b42f3a3e77 | ||
|
|
69d2fe4336 | ||
|
|
44bb3e5ac1 | ||
|
|
0dfed1cd06 | ||
|
|
2659d78582 | ||
|
|
9075dac669 | ||
|
|
e27ef0a07c | ||
|
|
0cfd4f1f5d | ||
|
|
76089fa012 | ||
|
|
5d0a7d6e84 |
22
Cargo.toml
22
Cargo.toml
@@ -1,15 +1,29 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "differential_equations"
|
name = "ordinary-diffeq"
|
||||||
version = "0.2.1"
|
version = "0.2.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
authors = ["Connor Johnstone"]
|
||||||
|
description = "A library for solving differential equations based on the DifferentialEquations.jl julia library."
|
||||||
|
readme = "readme.md"
|
||||||
|
repository = "https://gitlab.rcjohnstone.com/connor/differential-equations"
|
||||||
|
license = "MIT"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
nalgebra = { version = "0.32", features = ["serde-serialize"] }
|
nalgebra = { version = "0.34", features = ["serde-serialize"] }
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.19"
|
||||||
roots = "0.0.8"
|
roots = "0.0.8"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
approx = "0.5"
|
approx = "0.5"
|
||||||
|
criterion = "0.7.0"
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "simple_1d"
|
||||||
|
harness = false
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "orbit"
|
||||||
|
harness = false
|
||||||
|
|||||||
40
benches/orbit.rs
Normal file
40
benches/orbit.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
|
||||||
|
use ordinary_diffeq::prelude::*;
|
||||||
|
use nalgebra::Vector6;
|
||||||
|
|
||||||
|
fn bench_orbit(c: &mut Criterion) {
|
||||||
|
let mu = 3.98600441500000e14;
|
||||||
|
|
||||||
|
// Set up the system
|
||||||
|
type Params = (f64,);
|
||||||
|
let params = (mu,);
|
||||||
|
fn derivative(_t: f64, state: Vector6<f64>, p: &Params) -> Vector6<f64> {
|
||||||
|
let acc = -(p.0 * state.fixed_rows::<3>(0)) / (state.fixed_rows::<3>(0).norm().powi(3));
|
||||||
|
Vector6::new(state[3], state[4], state[5], acc[0], acc[1], acc[2])
|
||||||
|
}
|
||||||
|
let y0 = Vector6::new(
|
||||||
|
4.263868426884883e6,
|
||||||
|
5.146189057155391e6,
|
||||||
|
1.1310208421331816e6,
|
||||||
|
-5923.454461876975,
|
||||||
|
4496.802639690076,
|
||||||
|
1870.3893008991558,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Integrate
|
||||||
|
let ode = ODE::new(&derivative, 0.0, 86400.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new();
|
||||||
|
let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01);
|
||||||
|
|
||||||
|
c.bench_function("bench_orbit", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
std::hint::black_box({
|
||||||
|
Problem::new(ode, dp45, controller).solve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, bench_orbit);
|
||||||
|
criterion_main!(benches);
|
||||||
56
benches/simple_1d.rs
Normal file
56
benches/simple_1d.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use criterion::{criterion_group, criterion_main, Criterion};
|
||||||
|
|
||||||
|
use ordinary_diffeq::prelude::*;
|
||||||
|
use nalgebra::Vector1;
|
||||||
|
|
||||||
|
fn bench_simple_1d(c: &mut Criterion) {
|
||||||
|
type Params = (f64,);
|
||||||
|
let params = (0.1,);
|
||||||
|
|
||||||
|
fn derivative(_t: f64, y: Vector1<f64>, p: &Params) -> Vector1<f64> {
|
||||||
|
Vector1::new(-p.0 * y[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
let y0 = Vector1::new(1.0);
|
||||||
|
|
||||||
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
|
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6);
|
||||||
|
let controller = PIController::default();
|
||||||
|
|
||||||
|
c.bench_function("bench_simple_1d", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
std::hint::black_box({
|
||||||
|
Problem::new(ode, dp45, controller).solve();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_interpolation_1d(c: &mut Criterion) {
|
||||||
|
type Params = (f64,);
|
||||||
|
let params = (0.1,);
|
||||||
|
|
||||||
|
fn derivative(_t: f64, y: Vector1<f64>, p: &Params) -> Vector1<f64> {
|
||||||
|
Vector1::new(-p.0 * y[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
let y0 = Vector1::new(1.0);
|
||||||
|
|
||||||
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
|
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new().a_tol(1e-6).r_tol(1e-6);
|
||||||
|
let controller = PIController::default();
|
||||||
|
|
||||||
|
c.bench_function("bench_interpolation_1d", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
std::hint::black_box({
|
||||||
|
let solution = Problem::new(ode, dp45, controller).solve();
|
||||||
|
let _ = (0..100).map(|t| solution.interpolate(t as f64 * 0.1)[0]);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, bench_simple_1d, bench_interpolation_1d,);
|
||||||
|
criterion_main!(benches);
|
||||||
@@ -53,7 +53,7 @@ let y0 = Vector2::new(0.0, PI/2.0);
|
|||||||
|
|
||||||
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
let ode = ODE::new(&derivative, 0.0, 6.3, y0, params);
|
let ode = ODE::new(&derivative, 0.0, 6.3, y0, params);
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
|
||||||
let controller = PIController::default();
|
let controller = PIController::default();
|
||||||
|
|
||||||
let value_too_high = Callback {
|
let value_too_high = Callback {
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ pub struct Callback<'a, const D: usize, P> {
|
|||||||
pub event: &'a dyn Fn(f64, SVector<f64, D>, &P) -> f64,
|
pub event: &'a dyn Fn(f64, SVector<f64, D>, &P) -> f64,
|
||||||
|
|
||||||
/// The function to change the ODE
|
/// The function to change the ODE
|
||||||
pub effect: &'a dyn Fn(&mut ODE<D, P>) -> (),
|
pub effect: &'a dyn Fn(&mut ODE<D, P>),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A convenience function for stopping the integration
|
/// A convenience function for stopping the integration
|
||||||
pub fn stop<'a, const D: usize, P>(ode: &mut ODE<D, P>) -> () {
|
pub fn stop<const D: usize, P>(ode: &mut ODE<D, P>) {
|
||||||
ode.t_end = ode.t;
|
ode.t_end = ode.t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,31 @@
|
|||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub enum TryStep {
|
||||||
|
Accepted(f64, f64),
|
||||||
|
NotYetAccepted(f64),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryStep {
|
||||||
|
pub fn extract(&self) -> f64 {
|
||||||
|
match self {
|
||||||
|
TryStep::Accepted(h, _) => *h,
|
||||||
|
TryStep::NotYetAccepted(h) => *h,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_accepted(&self) -> bool {
|
||||||
|
matches!(self, TryStep::Accepted(_, _))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) -> Result<TryStep, &str> {
|
||||||
|
match self {
|
||||||
|
TryStep::Accepted(_, h) => Ok(TryStep::NotYetAccepted(*h)),
|
||||||
|
TryStep::NotYetAccepted(_) => Err("Cannot reset a NotYetAccepted TryStep"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait Controller<const D: usize> {
|
pub trait Controller<const D: usize> {
|
||||||
fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64);
|
fn determine_step(&mut self, h: f64, err: f64) -> TryStep;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
@@ -11,32 +37,30 @@ pub struct PIController {
|
|||||||
pub factor_old: f64,
|
pub factor_old: f64,
|
||||||
pub h_max: f64,
|
pub h_max: f64,
|
||||||
pub safety_factor: f64,
|
pub safety_factor: f64,
|
||||||
pub old_h: f64,
|
pub next_step_guess: TryStep,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const D: usize> Controller<D> for PIController {
|
impl<const D: usize> Controller<D> for PIController {
|
||||||
/// Determines if the previously run step size and error were valid or not. Either way, it also
|
/// Determines if the previously run step size and error were valid or not. Either way, it also
|
||||||
/// returns what the next step size should be
|
/// returns what the next step size should be
|
||||||
fn determine_step(&mut self, h: f64, err: f64) -> (bool, f64) {
|
fn determine_step(&mut self, prev_step: f64, err: f64) -> TryStep {
|
||||||
let factor_11 = err.powf(self.alpha);
|
let factor_11 = err.powf(self.alpha);
|
||||||
let factor = self.factor_c2.max(
|
let factor = self.factor_c2.max(
|
||||||
self.factor_c1
|
self.factor_c1
|
||||||
.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor),
|
.min(factor_11 * self.factor_old.powf(-self.beta) / self.safety_factor),
|
||||||
);
|
);
|
||||||
let mut h_new = h / factor;
|
|
||||||
if err <= 1.0 {
|
if err <= 1.0 {
|
||||||
// Accept the stepsize
|
let mut h = prev_step / factor;
|
||||||
|
// Accept the stepsize and provide what the next step size should be
|
||||||
self.factor_old = err.max(1.0e-4);
|
self.factor_old = err.max(1.0e-4);
|
||||||
if h_new.abs() > self.h_max {
|
if h.abs() > self.h_max {
|
||||||
// If the step is too big
|
// If the step goes past the maximum allowed, though, we shrink it
|
||||||
h_new = self.h_max.copysign(h_new);
|
h = self.h_max.copysign(h);
|
||||||
}
|
}
|
||||||
(true, h_new)
|
TryStep::Accepted(prev_step, h)
|
||||||
// (true, h_new)
|
|
||||||
} else {
|
} else {
|
||||||
// Reject the stepsize and propose a smaller one
|
// Reject the stepsize and propose a smaller one for the current step
|
||||||
h_new = h / (self.factor_c1.min(factor_11 / self.safety_factor));
|
TryStep::NotYetAccepted(prev_step / (self.factor_c1.min(factor_11 / self.safety_factor)))
|
||||||
(false, h_new)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -52,17 +76,20 @@ impl PIController {
|
|||||||
initial_h: f64,
|
initial_h: f64,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
alpha: alpha,
|
alpha,
|
||||||
beta: beta,
|
beta,
|
||||||
factor_c1: 1.0 / min_factor,
|
factor_c1: 1.0 / min_factor,
|
||||||
factor_c2: 1.0 / max_factor,
|
factor_c2: 1.0 / max_factor,
|
||||||
factor_old: 1.0e-4,
|
factor_old: 1.0e-4,
|
||||||
h_max: h_max.abs(),
|
h_max: h_max.abs(),
|
||||||
safety_factor: safety_factor,
|
safety_factor,
|
||||||
old_h: initial_h,
|
next_step_guess: TryStep::NotYetAccepted(initial_h),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn default() -> Self {
|
}
|
||||||
|
|
||||||
|
impl Default for PIController {
|
||||||
|
fn default() -> Self {
|
||||||
Self::new(0.17, 0.04, 10.0, 0.2, 100000.0, 0.9, 1e-4)
|
Self::new(0.17, 0.04, 10.0, 0.2, 100000.0, 0.9, 1e-4)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -82,6 +109,6 @@ mod tests {
|
|||||||
assert!(controller.factor_old == 1.0e-4);
|
assert!(controller.factor_old == 1.0e-4);
|
||||||
assert!(controller.h_max == 10.0);
|
assert!(controller.h_max == 10.0);
|
||||||
assert!(controller.safety_factor == 0.9);
|
assert!(controller.safety_factor == 0.9);
|
||||||
assert!(controller.old_h == 1e-4);
|
assert!(controller.next_step_guess == TryStep::NotYetAccepted(1e-4));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ pub trait DormandPrinceIntegrator<'a> {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct DormandPrince45<const D: usize> {
|
pub struct DormandPrince45<const D: usize> {
|
||||||
a_tol: f64,
|
a_tol: SVector<f64,D>,
|
||||||
r_tol: f64,
|
r_tol: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,8 +21,17 @@ impl<const D: usize> DormandPrince45<D>
|
|||||||
where
|
where
|
||||||
DormandPrince45<D>: Integrator<D>,
|
DormandPrince45<D>: Integrator<D>,
|
||||||
{
|
{
|
||||||
pub fn new(a_tol: f64, r_tol: f64) -> Self {
|
pub fn new() -> Self {
|
||||||
Self { a_tol, r_tol }
|
Self { a_tol: SVector::<f64,D>::from_element(1e-8), r_tol: 1e-8 }
|
||||||
|
}
|
||||||
|
pub fn a_tol(&mut self, a_tol: f64) -> Self {
|
||||||
|
Self { a_tol: SVector::<f64,D>::from_element(a_tol), r_tol: self.r_tol }
|
||||||
|
}
|
||||||
|
pub fn a_tol_full(&mut self, a_tol: SVector::<f64,D>) -> Self {
|
||||||
|
Self { a_tol, r_tol: self.r_tol }
|
||||||
|
}
|
||||||
|
pub fn r_tol(&mut self, r_tol: f64) -> Self {
|
||||||
|
Self { a_tol: self.a_tol, r_tol }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,7 +102,7 @@ where
|
|||||||
h: f64,
|
h: f64,
|
||||||
) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) {
|
) -> (SVector<f64, D>, Option<f64>, Option<Vec<SVector<f64, D>>>) {
|
||||||
let mut k: Vec<SVector<f64, D>> = vec![SVector::<f64, D>::zeros(); Self::STAGES];
|
let mut k: Vec<SVector<f64, D>> = vec![SVector::<f64, D>::zeros(); Self::STAGES];
|
||||||
let mut next_y = ode.y.clone();
|
let mut next_y = ode.y;
|
||||||
let mut err = SVector::<f64, D>::zeros();
|
let mut err = SVector::<f64, D>::zeros();
|
||||||
let mut rcont5 = SVector::<f64, D>::zeros();
|
let mut rcont5 = SVector::<f64, D>::zeros();
|
||||||
// Do the first of the summations
|
// Do the first of the summations
|
||||||
@@ -106,8 +115,8 @@ where
|
|||||||
for i in 1..Self::STAGES {
|
for i in 1..Self::STAGES {
|
||||||
// Compute the ks
|
// Compute the ks
|
||||||
let mut y_term = SVector::<f64, D>::zeros();
|
let mut y_term = SVector::<f64, D>::zeros();
|
||||||
for j in 0..i {
|
for (j, item) in k.iter().enumerate().take(i) {
|
||||||
y_term += k[j] * Self::A[(i * (i - 1)) / 2 + j];
|
y_term += item * Self::A[(i * (i - 1)) / 2 + j];
|
||||||
}
|
}
|
||||||
k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params);
|
k[i] = (ode.f)(ode.t + Self::C[i] * h, ode.y + y_term * h, &ode.params);
|
||||||
|
|
||||||
@@ -119,7 +128,7 @@ where
|
|||||||
let rcont2 = next_y - ode.y;
|
let rcont2 = next_y - ode.y;
|
||||||
let rcont3 = h * k[0] - rcont2;
|
let rcont3 = h * k[0] - rcont2;
|
||||||
let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3;
|
let rcont4 = rcont2 - k[Self::STAGES - 1] * h - rcont3;
|
||||||
let tol = SVector::<f64, D>::repeat(self.a_tol) + ode.y * self.r_tol;
|
let tol = self.a_tol + ode.y * self.r_tol;
|
||||||
let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5];
|
let rcont = vec![rcont1, rcont2, rcont3, rcont4, rcont5];
|
||||||
(next_y, Some((err.component_div(&tol)).norm()), Some(rcont))
|
(next_y, Some((err.component_div(&tol)).norm()), Some(rcont))
|
||||||
}
|
}
|
||||||
@@ -127,7 +136,7 @@ where
|
|||||||
&self,
|
&self,
|
||||||
t_start: f64,
|
t_start: f64,
|
||||||
t_end: f64,
|
t_end: f64,
|
||||||
dense: &Vec<SVector<f64, D>>,
|
dense: &[SVector<f64, D>],
|
||||||
t: f64,
|
t: f64,
|
||||||
) -> SVector<f64, D> {
|
) -> SVector<f64, D> {
|
||||||
let s = (t - t_start) / (t_end - t_start);
|
let s = (t - t_start) / (t_end - t_start);
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ pub trait Integrator<const D: usize> {
|
|||||||
&self,
|
&self,
|
||||||
t_start: f64,
|
t_start: f64,
|
||||||
t_end: f64,
|
t_end: f64,
|
||||||
dense: &Vec<SVector<f64, D>>,
|
dense: &[SVector<f64, D>],
|
||||||
t: f64,
|
t: f64,
|
||||||
) -> SVector<f64, D>;
|
) -> SVector<f64, D>;
|
||||||
}
|
}
|
||||||
@@ -44,7 +44,7 @@ mod tests {
|
|||||||
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
||||||
let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ());
|
let mut ode = ODE::new(&derivative, 0.0, 4.0, y0, ());
|
||||||
|
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-4_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-4);
|
||||||
|
|
||||||
// Test that y'(t) = y(t) solves to y(t) = e^t for rkf54
|
// Test that y'(t) = y(t) solves to y(t) = e^t for rkf54
|
||||||
// and also that the error seems reasonable
|
// and also that the error seems reasonable
|
||||||
|
|||||||
34
src/lib.rs
34
src/lib.rs
@@ -18,7 +18,7 @@ pub mod prelude {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use crate::prelude::*;
|
use crate::prelude::*;
|
||||||
use approx::assert_relative_eq;
|
use approx::assert_relative_eq;
|
||||||
use nalgebra::{Vector2, Vector6};
|
use nalgebra::{Vector1, Vector2, Vector6};
|
||||||
use std::f64::consts::PI;
|
use std::f64::consts::PI;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -38,7 +38,7 @@ mod tests {
|
|||||||
|
|
||||||
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
let ode = ODE::new(&derivative, 0.0, 6.3, y0, params);
|
let ode = ODE::new(&derivative, 0.0, 6.3, y0, params);
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
|
||||||
let controller = PIController::default();
|
let controller = PIController::default();
|
||||||
|
|
||||||
let value_too_high = Callback {
|
let value_too_high = Callback {
|
||||||
@@ -54,6 +54,32 @@ mod tests {
|
|||||||
let _interpolated_answer = solution.interpolate(4.4);
|
let _interpolated_answer = solution.interpolate(4.4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_correctness() {
|
||||||
|
// Define the system (parameters, derivative, and initial state)
|
||||||
|
type Params = ();
|
||||||
|
let params = ();
|
||||||
|
|
||||||
|
fn derivative(_t: f64, y: Vector1<f64>, _p: &Params) -> Vector1<f64> {
|
||||||
|
Vector1::new(5.0 * y[0] - 3.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
let y0 = Vector1::new(1.0);
|
||||||
|
|
||||||
|
// Set up the problem (ODE, Integrator, Controller, and Callbacks)
|
||||||
|
let ode = ODE::new(&derivative, 2.0, 3.0, y0, params);
|
||||||
|
let dp45 = DormandPrince45::new();
|
||||||
|
let controller = PIController::default();
|
||||||
|
|
||||||
|
// Solve the problem
|
||||||
|
let mut problem = Problem::new(ode, dp45, controller);
|
||||||
|
let solution = problem.solve();
|
||||||
|
for (time, state) in solution.times.iter().zip(solution.states.iter()) {
|
||||||
|
let exact = 0.4 * (5.0 * (time - 2.0)).exp() + 0.6;
|
||||||
|
assert_relative_eq!(state[0], exact, max_relative = 1e-7);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_orbit() {
|
fn test_orbit() {
|
||||||
// Calculate one period
|
// Calculate one period
|
||||||
@@ -79,11 +105,9 @@ mod tests {
|
|||||||
|
|
||||||
// Integrate
|
// Integrate
|
||||||
let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params);
|
let ode = ODE::new(&derivative, 0.0, 10.0 * period, y0, params);
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-12_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-12);
|
||||||
let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01);
|
let controller = PIController::new(0.37, 0.04, 10.0, 0.2, 1000.0, 0.9, 0.01);
|
||||||
|
|
||||||
let mut problem = Problem::new(ode, dp45, controller);
|
let mut problem = Problem::new(ode, dp45, controller);
|
||||||
|
|
||||||
let solution = problem.solve();
|
let solution = problem.solve();
|
||||||
|
|
||||||
assert_relative_eq!(
|
assert_relative_eq!(
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
use nalgebra::SVector;
|
use nalgebra::SVector;
|
||||||
|
|
||||||
|
type ProblemFunction<'a, const D: usize, P> = &'a dyn Fn(f64, SVector<f64, D>, &P) -> SVector<f64, D>;
|
||||||
|
|
||||||
/// The basic ODE object that will be passed around. The type (T) and the size (D) will be
|
/// The basic ODE object that will be passed around. The type (T) and the size (D) will be
|
||||||
/// determined upon creation of the object
|
/// determined upon creation of the object
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub struct ODE<'a, const D: usize, P> {
|
pub struct ODE<'a, const D: usize, P> {
|
||||||
pub f: &'a dyn Fn(f64, SVector<f64, D>, &P) -> SVector<f64, D>,
|
pub f: ProblemFunction<'a, D, P>,
|
||||||
pub y: SVector<f64, D>,
|
pub y: SVector<f64, D>,
|
||||||
pub t: f64,
|
pub t: f64,
|
||||||
pub params: P,
|
pub params: P,
|
||||||
@@ -16,7 +18,7 @@ pub struct ODE<'a, const D: usize, P> {
|
|||||||
|
|
||||||
impl<'a, const D: usize, P> ODE<'a, D, P> {
|
impl<'a, const D: usize, P> ODE<'a, D, P> {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
f: &'a (dyn Fn(f64, SVector<f64, D>, &P) -> SVector<f64, D>),
|
f: ProblemFunction<'a, D, P>,
|
||||||
t0: f64,
|
t0: f64,
|
||||||
t_end: f64,
|
t_end: f64,
|
||||||
y0: SVector<f64, D>,
|
y0: SVector<f64, D>,
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
use nalgebra::SVector;
|
use nalgebra::SVector;
|
||||||
use roots::{find_root_brent, DebugConvergency};
|
use roots::{find_root_brent, SimpleConvergency};
|
||||||
|
|
||||||
use super::callback::Callback;
|
use super::callback::Callback;
|
||||||
use super::controller::{Controller, PIController};
|
use super::controller::{Controller, PIController, TryStep};
|
||||||
use super::integrator::Integrator;
|
use super::integrator::Integrator;
|
||||||
use super::ode::ODE;
|
use super::ode::ODE;
|
||||||
|
|
||||||
@@ -23,48 +23,64 @@ where
|
|||||||
{
|
{
|
||||||
pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self {
|
pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self {
|
||||||
Problem {
|
Problem {
|
||||||
ode: ode,
|
ode,
|
||||||
integrator: integrator,
|
integrator,
|
||||||
controller: controller,
|
controller,
|
||||||
callbacks: Vec::new(),
|
callbacks: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pub fn solve(&mut self) -> Solution<S, D> {
|
pub fn solve(&mut self) -> Solution<S, D> {
|
||||||
let mut convergency = DebugConvergency::new(1e-12, 50);
|
let mut convergency = SimpleConvergency {
|
||||||
|
eps: 1e-12,
|
||||||
|
max_iter: 1000,
|
||||||
|
};
|
||||||
let mut times: Vec<f64> = vec![self.ode.t];
|
let mut times: Vec<f64> = vec![self.ode.t];
|
||||||
let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
|
let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
|
||||||
let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();
|
let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();
|
||||||
let mut step: f64 = self.controller.old_h;
|
|
||||||
let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0);
|
|
||||||
while self.ode.t < self.ode.t_end {
|
while self.ode.t < self.ode.t_end {
|
||||||
let mut dense_option: Option<Vec<SVector<f64, D>>> = None;
|
if self.ode.t + self.controller.next_step_guess.extract() > self.ode.t_end {
|
||||||
if S::ADAPTIVE {
|
// If the next step would go past the end, then just set it to the end
|
||||||
|
self.controller.next_step_guess = TryStep::NotYetAccepted(
|
||||||
|
self.ode.t_end - self.ode.t,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let (mut new_y, mut curr_step, mut dense_option) = if S::ADAPTIVE {
|
||||||
|
// First, we try stepping with the "next step guess" to get the error
|
||||||
|
let (mut trial_y, mut err_option, mut dense_option) =
|
||||||
|
self.integrator.step(&self.ode, self.controller.next_step_guess.extract());
|
||||||
let mut err = err_option.unwrap();
|
let mut err = err_option.unwrap();
|
||||||
let mut accepted: bool = false;
|
// Then we determine whether we need to reduce the step size or not
|
||||||
while !accepted {
|
// If successful, we get the next step guess
|
||||||
// Try a step and if that isn't acceptable, then change the step until it is
|
let initial_guess = self.controller.next_step_guess.extract();
|
||||||
(accepted, step) = <PIController as Controller<D>>::determine_step(
|
let mut next_step_guess = <PIController as Controller<D>>::determine_step(
|
||||||
&mut self.controller,
|
&mut self.controller,
|
||||||
step,
|
initial_guess,
|
||||||
|
err,
|
||||||
|
);
|
||||||
|
while !next_step_guess.is_accepted() {
|
||||||
|
// If that step isn't acceptable, then change the step until it is
|
||||||
|
(trial_y, err_option, dense_option) =
|
||||||
|
self.integrator.step(&self.ode, next_step_guess.extract());
|
||||||
|
next_step_guess = <PIController as Controller<D>>::determine_step(
|
||||||
|
&mut self.controller,
|
||||||
|
next_step_guess.extract(),
|
||||||
err,
|
err,
|
||||||
);
|
);
|
||||||
(new_y, err_option, dense_option) = self.integrator.step(&self.ode, step);
|
|
||||||
err = err_option.unwrap();
|
err = err_option.unwrap();
|
||||||
}
|
}
|
||||||
self.controller.old_h = step;
|
// So at this point we can safely assume we have an accepted step
|
||||||
self.controller.h_max = self
|
self.controller.next_step_guess = next_step_guess.reset().unwrap();
|
||||||
.controller
|
(trial_y, next_step_guess.extract(), dense_option)
|
||||||
.h_max
|
|
||||||
.min(self.ode.t_end - self.ode.t - step);
|
|
||||||
} else {
|
} else {
|
||||||
// If fixed time step just step forward one step
|
// If fixed time step just step forward one step
|
||||||
(new_y, _, dense_option) = self.integrator.step(&self.ode, step);
|
let (trial_y, _, dense_option) = self.integrator.step(&self.ode, self.controller.next_step_guess.extract());
|
||||||
}
|
(trial_y, self.controller.next_step_guess.extract(), dense_option)
|
||||||
if self.callbacks.len() > 0 {
|
};
|
||||||
|
if !self.callbacks.is_empty() {
|
||||||
// Check for events occurring
|
// Check for events occurring
|
||||||
for callback in &self.callbacks {
|
for callback in &self.callbacks {
|
||||||
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params)
|
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params)
|
||||||
* (callback.event)(self.ode.t + step, new_y, &self.ode.params)
|
* (callback.event)(self.ode.t + curr_step, new_y, &self.ode.params)
|
||||||
< 0.0
|
< 0.0
|
||||||
{
|
{
|
||||||
// If the event crossed zero, then find the root
|
// If the event crossed zero, then find the root
|
||||||
@@ -72,15 +88,15 @@ where
|
|||||||
let test_y = self.integrator.step(&self.ode, test_t).0;
|
let test_y = self.integrator.step(&self.ode, test_t).0;
|
||||||
(callback.event)(self.ode.t + test_t, test_y, &self.ode.params)
|
(callback.event)(self.ode.t + test_t, test_y, &self.ode.params)
|
||||||
};
|
};
|
||||||
let root = find_root_brent(0.0, step, &f, &mut convergency).unwrap();
|
let root = find_root_brent(0.0, curr_step, &f, &mut convergency).unwrap();
|
||||||
step = root;
|
curr_step = root;
|
||||||
(new_y, _, dense_option) = self.integrator.step(&self.ode, step);
|
(new_y, _, dense_option) = self.integrator.step(&self.ode, curr_step);
|
||||||
(callback.effect)(&mut self.ode);
|
(callback.effect)(&mut self.ode);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.ode.y = new_y;
|
self.ode.y = new_y;
|
||||||
self.ode.t += step;
|
self.ode.t += curr_step;
|
||||||
times.push(self.ode.t);
|
times.push(self.ode.t);
|
||||||
states.push(self.ode.y);
|
states.push(self.ode.y);
|
||||||
// TODO: Implement third order interpolation for non-dense algorithms
|
// TODO: Implement third order interpolation for non-dense algorithms
|
||||||
@@ -134,20 +150,17 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Then find the two t values closest to the desired t
|
// Then find the two t values closest to the desired t
|
||||||
let mut end_index: usize = 0;
|
match times.binary_search_by(|x| x.total_cmp(&t)) {
|
||||||
for (i, time) in self.times.iter().enumerate() {
|
Ok(index) => self.states[index],
|
||||||
if time > &t {
|
Err(end_index) => {
|
||||||
end_index = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then send that to the integrator
|
// Then send that to the integrator
|
||||||
let t_start = times[end_index - 1];
|
let t_start = times[end_index - 1];
|
||||||
let t_end = times[end_index];
|
let t_end = times[end_index];
|
||||||
self.integrator
|
self.integrator
|
||||||
.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
|
.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -168,7 +181,7 @@ mod tests {
|
|||||||
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
||||||
|
|
||||||
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
|
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
|
||||||
let controller = PIController::default();
|
let controller = PIController::default();
|
||||||
|
|
||||||
let mut problem = Problem::new(ode, dp45, controller);
|
let mut problem = Problem::new(ode, dp45, controller);
|
||||||
@@ -192,7 +205,7 @@ mod tests {
|
|||||||
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
||||||
|
|
||||||
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
|
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-5);
|
||||||
let controller = PIController::default();
|
let controller = PIController::default();
|
||||||
|
|
||||||
let value_too_high = Callback {
|
let value_too_high = Callback {
|
||||||
@@ -203,7 +216,11 @@ mod tests {
|
|||||||
let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high);
|
let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high);
|
||||||
let solution = problem.solve();
|
let solution = problem.solve();
|
||||||
|
|
||||||
assert!(solution.states.last().unwrap()[0] == 10.0);
|
assert_relative_eq!(
|
||||||
|
solution.states.last().unwrap()[0],
|
||||||
|
10.0,
|
||||||
|
max_relative = 1e-11
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -215,7 +232,7 @@ mod tests {
|
|||||||
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
let y0 = Vector3::new(1.0, 1.0, 1.0);
|
||||||
|
|
||||||
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
|
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
|
||||||
let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64);
|
let dp45 = DormandPrince45::new().a_tol(1e-12).r_tol(1e-6);
|
||||||
let controller = PIController::default();
|
let controller = PIController::default();
|
||||||
|
|
||||||
let mut problem = Problem::new(ode, dp45, controller);
|
let mut problem = Problem::new(ode, dp45, controller);
|
||||||
|
|||||||
Reference in New Issue
Block a user