Formatting changes

This commit is contained in:
Connor Johnstone
2023-10-16 14:16:44 -06:00
parent 16e958c5db
commit 9d82f92c07
7 changed files with 194 additions and 103 deletions

View File

@@ -1,10 +1,10 @@
use nalgebra::SVector;
use roots::find_root_regula_falsi;
use super::ode::ODE;
use super::callback::Callback;
use super::controller::{Controller, PIController};
use super::integrator::Integrator;
use super::callback::Callback;
use super::ode::ODE;
#[derive(Clone)]
pub struct Problem<'a, const D: usize, S, P>
@@ -17,11 +17,11 @@ where
callbacks: Vec<Callback<'a, D, P>>,
}
impl<'a, const D: usize, S, P> Problem<'a,D,S,P>
impl<'a, const D: usize, S, P> Problem<'a, D, S, P>
where
S: Integrator<D> + Copy,
{
pub fn new(ode: ODE<'a,D,P>, integrator: S, controller: PIController) -> Self {
pub fn new(ode: ODE<'a, D, P>, integrator: S, controller: PIController) -> Self {
Problem {
ode: ode,
integrator: integrator,
@@ -30,24 +30,31 @@ where
}
}
pub fn solve(&mut self) -> Solution<S, D> {
let mut times: Vec::<f64> = vec![self.ode.t];
let mut states: Vec::<SVector<f64,D>> = vec![self.ode.y];
let mut dense_coefficients: Vec::<Vec<SVector<f64,D>>> = Vec::new();
let mut times: Vec<f64> = vec![self.ode.t];
let mut states: Vec<SVector<f64, D>> = vec![self.ode.y];
let mut dense_coefficients: Vec<Vec<SVector<f64, D>>> = Vec::new();
let mut step: f64 = self.controller.old_h;
let (mut new_y, mut err_option, _) = self.integrator.step(&self.ode, 0.0);
while self.ode.t < self.ode.t_end {
let mut dense_option: Option<Vec<SVector<f64,D>>> = None;
let mut dense_option: Option<Vec<SVector<f64, D>>> = None;
if S::ADAPTIVE {
let mut err = err_option.unwrap();
let mut accepted: bool = false;
while !accepted {
// Try a step and if that isn't acceptable, then change the step until it is
(accepted, step) = <PIController as Controller<D>>::determine_step(&mut self.controller, step, err);
(accepted, step) = <PIController as Controller<D>>::determine_step(
&mut self.controller,
step,
err,
);
(new_y, err_option, dense_option) = self.integrator.step(&self.ode, step);
err = err_option.unwrap();
}
self.controller.old_h = step;
self.controller.h_max = self.controller.h_max.min(self.ode.t_end - self.ode.t - step);
self.controller.h_max = self
.controller
.h_max
.min(self.ode.t_end - self.ode.t - step);
} else {
// If fixed time step just step forward one step
(new_y, _, dense_option) = self.integrator.step(&self.ode, step);
@@ -55,7 +62,10 @@ where
if self.callbacks.len() > 0 {
// Check for events occurring
for callback in &self.callbacks {
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params) < 0.0 {
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params)
* (callback.event)(self.ode.t + step, new_y, &self.ode.params)
< 0.0
{
// If the event crossed zero, then find the root
let f = |test_t| {
let test_y = self.integrator.step(&self.ode, test_t).0;
@@ -77,8 +87,8 @@ where
}
Solution {
integrator: self.integrator,
times: times,
states: states,
times,
states,
dense: dense_coefficients,
}
}
@@ -94,14 +104,20 @@ where
}
}
pub struct Solution<S, const D: usize> where S: Integrator<D> {
pub struct Solution<S, const D: usize>
where
S: Integrator<D>,
{
pub integrator: S,
pub times: Vec<f64>,
pub states: Vec<SVector<f64,D>>,
pub dense: Vec::<Vec<SVector<f64,D>>>,
pub states: Vec<SVector<f64, D>>,
pub dense: Vec<Vec<SVector<f64, D>>>,
}
impl<S, const D: usize> Solution<S,D> where S: Integrator<D> {
impl<S, const D: usize> Solution<S, D>
where
S: Integrator<D>,
{
pub fn interpolate(&self, t: f64) -> SVector<f64, D> {
// First check that the t is within bounds
let last = self.times.last().unwrap();
@@ -109,8 +125,12 @@ impl<S, const D: usize> Solution<S,D> where S: Integrator<D> {
// TODO: Improve these errors
let mut times = self.times.clone();
if *first > *last { times.reverse(); }
if t < *first || t > *last { panic!(); }
if *first > *last {
times.reverse();
}
if t < *first || t > *last {
panic!();
}
// Then find the two t values closest to the desired t
let mut end_index: usize = 0;
@@ -124,23 +144,26 @@ impl<S, const D: usize> Solution<S,D> where S: Integrator<D> {
// Then send that to the integrator
let t_start = times[end_index - 1];
let t_end = times[end_index];
self.integrator.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
self.integrator
.interpolate(t_start, t_end, &self.dense[end_index - 1], t)
}
}
#[cfg(test)]
mod tests {
use super::*;
use nalgebra::Vector3;
use approx::assert_relative_eq;
use crate::integrator::dormand_prince::DormandPrince45;
use crate::controller::PIController;
use crate::callback::stop;
use crate::controller::PIController;
use crate::integrator::dormand_prince::DormandPrince45;
use approx::assert_relative_eq;
use nalgebra::Vector3;
#[test]
fn test_problem() {
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y }
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
@@ -150,15 +173,21 @@ mod tests {
let mut problem = Problem::new(ode, dp45, controller);
let solution = problem.solve();
solution.times.iter().zip(solution.states.iter()).for_each(|(time, state)| {
assert_relative_eq!(state[0], time.exp(), max_relative=1e-2);
})
solution
.times
.iter()
.zip(solution.states.iter())
.for_each(|(time, state)| {
assert_relative_eq!(state[0], time.exp(), max_relative = 1e-2);
})
}
#[test]
fn test_with_callback() {
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y }
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
@@ -166,7 +195,7 @@ mod tests {
let controller = PIController::default();
let value_too_high = Callback {
event: &|_: f64, y: SVector<f64,3>, _: &Params| { 10.0 - y[0] },
event: &|_: f64, y: SVector<f64, 3>, _: &Params| 10.0 - y[0],
effect: &stop,
};
@@ -179,7 +208,9 @@ mod tests {
#[test]
fn test_with_interpolation() {
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y }
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> {
y
}
let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
@@ -189,6 +220,10 @@ mod tests {
let mut problem = Problem::new(ode, dp45, controller);
let solution = problem.solve();
assert_relative_eq!(solution.interpolate(8.8)[0], 8.8_f64.exp(), max_relative=1e-6);
assert_relative_eq!(
solution.interpolate(8.8)[0],
8.8_f64.exp(),
max_relative = 1e-6
);
}
}