194 lines
7.0 KiB
Rust
194 lines
7.0 KiB
Rust
use nalgebra::SVector;
|
|
use roots::find_root_regula_falsi;
|
|
|
|
use super::ode::ODE;
|
|
use super::controller::{Controller, PIController};
|
|
use super::integrator::Integrator;
|
|
use super::callback::Callback;
|
|
|
|
#[derive(Clone)]
|
|
pub struct Problem<'a, const D: usize, S>
|
|
where
|
|
S: Integrator<D>,
|
|
{
|
|
ode: ODE<'a, D>,
|
|
integrator: S,
|
|
controller: PIController,
|
|
callbacks: Vec<Callback<'a, D>>,
|
|
}
|
|
|
|
impl<'a, const D: usize, S> Problem<'a,D,S>
|
|
where
|
|
S: Integrator<D> + Copy,
|
|
{
|
|
pub fn new(ode: ODE<'a,D>, integrator: S, controller: PIController) -> Self {
|
|
Problem {
|
|
ode: ode,
|
|
integrator: integrator,
|
|
controller: controller,
|
|
callbacks: Vec::new(),
|
|
}
|
|
}
|
|
pub fn solve(&mut self) -> Solution<S, D> {
|
|
let mut times: Vec::<f64> = vec![self.ode.t];
|
|
let mut states: Vec::<SVector<f64,D>> = vec![self.ode.y];
|
|
let mut dense_coefficients: Vec::<Vec<SVector<f64,D>>> = Vec::new();
|
|
let mut step: f64 = self.controller.old_h;
|
|
let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0);
|
|
while self.ode.t < self.ode.t_end {
|
|
let mut dense_option: Option<Vec<SVector<f64,D>>> = None;
|
|
if S::ADAPTIVE {
|
|
let mut err = err_option.unwrap();
|
|
let mut accepted: bool = false;
|
|
while !accepted {
|
|
// Try a step and if that isn't acceptable, then change the step until it is
|
|
(accepted, step) = <PIController as Controller<D>>::determine_step(&mut self.controller, step, err);
|
|
(new_y, err_option, dense_option) = self.integrator.step(&self.ode, step);
|
|
err = err_option.unwrap();
|
|
}
|
|
self.controller.old_h = step;
|
|
self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step);
|
|
} else {
|
|
// If fixed time step just step forward one step
|
|
(new_y, _, dense_option) = self.integrator.step(&self.ode, step);
|
|
}
|
|
if self.callbacks.len() > 0 {
|
|
// Check for events occurring
|
|
for callback in &self.callbacks {
|
|
println!("{}", (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y));
|
|
if (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y) < 0.0 {
|
|
// If the event crossed zero, then find the root
|
|
let f = |test_t| {
|
|
let test_y = self.integrator.step(&self.ode, test_t).0;
|
|
(callback.event)(self.ode.t + test_t, test_y)
|
|
};
|
|
let root = find_root_regula_falsi(0.0, step, &f, &mut 1e-12).unwrap();
|
|
step = root;
|
|
(new_y, _, dense_option) = self.integrator.step(&self.ode, step);
|
|
self.ode = (callback.effect)(self.ode);
|
|
}
|
|
}
|
|
}
|
|
self.ode.y = new_y;
|
|
self.ode.t += step;
|
|
times.push(self.ode.t);
|
|
states.push(self.ode.y);
|
|
// TODO: Implement third order interpolation for non-dense algorithms
|
|
dense_coefficients.push(dense_option.unwrap());
|
|
}
|
|
Solution {
|
|
integrator: self.integrator,
|
|
times: times,
|
|
states: states,
|
|
dense: dense_coefficients,
|
|
}
|
|
}
|
|
|
|
pub fn with_callback(mut self, callback: Callback<'a, D>) -> Self {
|
|
self.callbacks.push(callback);
|
|
Self {
|
|
ode: self.ode,
|
|
integrator: self.integrator,
|
|
controller: self.controller,
|
|
callbacks: self.callbacks,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct Solution<S, const D: usize> where S: Integrator<D> {
|
|
pub integrator: S,
|
|
pub times: Vec<f64>,
|
|
pub states: Vec<SVector<f64,D>>,
|
|
pub dense: Vec::<Vec<SVector<f64,D>>>,
|
|
}
|
|
|
|
impl<S, const D: usize> Solution<S,D> where S: Integrator<D> {
|
|
pub fn interpolate(&self, t: f64) -> SVector<f64, D> {
|
|
// First check that the t is within bounds
|
|
let last = self.times.last().unwrap();
|
|
let first = self.times.first().unwrap();
|
|
|
|
// TODO: Improve these errors
|
|
let mut times = self.times.clone();
|
|
if *first > *last { times.reverse(); }
|
|
if t < *first || t > *last { panic!(); }
|
|
|
|
// Then find the two t values closest to the desired t
|
|
let mut end_index: usize = 0;
|
|
for (i, time) in self.times.iter().enumerate() {
|
|
if time > &t {
|
|
end_index = i;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Then send that to the integrator
|
|
let t_start = times[end_index - 1];
|
|
let t_end = times[end_index];
|
|
self.integrator.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::Vector3;
    use approx::assert_relative_eq;
    use crate::integrator::dormand_prince::DormandPrince45;
    use crate::controller::PIController;
    use crate::callback::stop;

    /// Solves y' = y from t = 0 to t = 1 and compares every stored state with
    /// exp(t).
    #[test]
    fn test_problem() {
        fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { y }
        let y0 = Vector3::new(1.0, 1.0, 1.0);

        let ode = ODE::new(&derivative, 0.0, 1.0, y0);
        let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64);
        let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8);

        let mut problem = Problem::new(ode, dp45, controller);

        let solution = problem.solve();
        solution.times.iter().zip(solution.states.iter()).for_each(|(time, state)| {
            assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2);
        })
    }

    /// Solves y' = y with a callback whose event is `10.0 - y[0]` and whose
    /// effect is `stop`, then checks that the last stored state sits on the
    /// event.
    #[test]
    fn test_with_callback() {
        fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { y }
        let y0 = Vector3::new(1.0, 1.0, 1.0);

        let ode = ODE::new(&derivative, 0.0, 5.0, y0);
        let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64);
        let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8);

        let value_too_high = Callback {
            event: &|_: f64, y: SVector<f64, 3>| { 10.0 - y[0] },
            effect: &stop,
        };

        let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high);
        let solution = problem.solve();

        // The event time comes from a root search with a 1e-12 tolerance, so
        // compare with a tolerance rather than demanding bit-exact equality.
        assert_relative_eq!(solution.states.last().unwrap()[0], 10.0, max_relative = 1e-9);
    }

    /// Solves y' = y from t = 0 to t = 10 and checks the interpolated value
    /// at t = 8.8 against exp(8.8).
    #[test]
    fn test_with_interpolation() {
        fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { y }
        let y0 = Vector3::new(1.0, 1.0, 1.0);

        let ode = ODE::new(&derivative, 0.0, 10.0, y0);
        let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64);
        let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8);

        let mut problem = Problem::new(ode, dp45, controller);
        let solution = problem.solve();

        assert_relative_eq!(solution.interpolate(8.8)[0], 8.8_f64.exp(), max_relative = 1e-6);
    }
}
|