use nalgebra::SVector; use roots::find_root_regula_falsi; use super::ode::ODE; use super::controller::{Controller, PIController}; use super::integrator::Integrator; use super::callback::Callback; #[derive(Clone)] pub struct Problem<'a, const D: usize, S> where S: Integrator, { ode: ODE<'a, D>, integrator: S, controller: PIController, callbacks: Vec>, } impl<'a, const D: usize, S> Problem<'a,D,S> where S: Integrator + Copy, { pub fn new(ode: ODE<'a,D>, integrator: S, controller: PIController) -> Self { Problem { ode: ode, integrator: integrator, controller: controller, callbacks: Vec::new(), } } pub fn solve(&mut self) -> Solution { let mut times: Vec:: = vec![self.ode.t]; let mut states: Vec::> = vec![self.ode.y]; let mut dense_coefficients: Vec::>> = Vec::new(); let mut step: f64 = self.controller.old_h; let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0); while self.ode.t < self.ode.t_end { let mut dense_option: Option>> = None; if S::ADAPTIVE { let mut err = err_option.unwrap(); let mut accepted: bool = false; while !accepted { // Try a step and if that isn't acceptable, then change the step until it is (accepted, step) = >::determine_step(&mut self.controller, step, err); (new_y, err_option, dense_option) = self.integrator.step(&self.ode, step); err = err_option.unwrap(); } self.controller.old_h = step; self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step); } else { // If fixed time step just step forward one step (new_y, _, dense_option) = self.integrator.step(&self.ode, step); } if self.callbacks.len() > 0 { // Check for events occurring for callback in &self.callbacks { println!("{}", (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y)); if (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y) < 0.0 { // If the event crossed zero, then find the root let f = |test_t| { let test_y = self.integrator.step(&self.ode, test_t).0; (callback.event)(self.ode.t + test_t, test_y) }; let root = find_root_regula_falsi(0.0, step, &f, &mut 1e-12).unwrap(); step = root; (new_y, _, dense_option) = self.integrator.step(&self.ode, step); self.ode = (callback.effect)(self.ode); } } } self.ode.y = new_y; self.ode.t += step; times.push(self.ode.t); states.push(self.ode.y); // TODO: Implement third order interpolation for non-dense algorithms dense_coefficients.push(dense_option.unwrap()); } Solution { integrator: self.integrator, times: times, states: states, dense: dense_coefficients, } } pub fn with_callback(mut self, callback: Callback<'a, D>) -> Self { self.callbacks.push(callback); Self { ode: self.ode, integrator: self.integrator, controller: self.controller, callbacks: self.callbacks, } } } pub struct Solution where S: Integrator { pub integrator: S, pub times: Vec, pub states: Vec>, pub dense: Vec::>>, } impl Solution where S: Integrator { pub fn interpolate(&self, t: f64) -> SVector { // First check that the t is within bounds let last = self.times.last().unwrap(); let first = self.times.first().unwrap(); // TODO: Improve these errors let mut times = self.times.clone(); if *first > *last { times.reverse(); } if t < *first || t > *last { panic!(); } // Then find the two t values closest to the desired t let mut end_index: usize = 0; for (i, time) in self.times.iter().enumerate() { if time > &t { end_index = i; break; } } // Then send that to the integrator let t_start = times[end_index - 1]; let t_end = times[end_index]; self.integrator.interpolate(t_start, t_end, &self.dense[end_index - 1], t) } } #[cfg(test)] mod tests { use super::*; use nalgebra::Vector3; use approx::assert_relative_eq; use crate::integrator::dormand_prince::DormandPrince45; use crate::controller::PIController; use crate::callback::stop; #[test] fn test_problem() { fn derivative(_t: f64, y: Vector3) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 1.0, y0); let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8); let mut problem = Problem::new(ode, dp45, controller); let solution = problem.solve(); solution.times.iter().zip(solution.states.iter()).for_each(|(time, state)| { assert_relative_eq!(state[0], time.exp(), max_relative=1e-2); }) } #[test] fn test_with_callback() { fn derivative(_t: f64, y: Vector3) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 5.0, y0); let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64); let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8); let value_too_high = Callback { event: &|_: f64, y: SVector| { 10.0 - y[0] }, effect: &stop, }; let mut problem = Problem::new(ode, dp45, controller).with_callback(value_too_high); let solution = problem.solve(); println!("{}", solution.states.last().unwrap()[0]); assert!(solution.states.last().unwrap()[0] == 10.0); } #[test] fn test_with_interpolation() { fn derivative(_t: f64, y: Vector3) -> Vector3 { y } let y0 = Vector3::new(1.0, 1.0, 1.0); let ode = ODE::new(&derivative, 0.0, 10.0, y0); let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64); let controller = PIController::new(0.17, 0.04, 10.0, 0.2, 0.1, 0.9, 1e-8); let mut problem = Problem::new(ode, dp45, controller); let solution = problem.solve(); assert_relative_eq!(solution.interpolate(8.8)[0], 8.8_f64.exp(), max_relative=1e-6); } }