Added parameters

This commit is contained in:
Connor Johnstone
2023-03-15 15:28:56 -06:00
parent 8daee11ae7
commit 63b603a57c
7 changed files with 123 additions and 48 deletions

View File

@@ -7,21 +7,21 @@ use super::integrator::Integrator;
use super::callback::Callback;
#[derive(Clone)]
pub struct Problem<'a, const D: usize, S>
pub struct Problem<'a, const D: usize, S, P>
where
S: Integrator<D>,
{
ode: ODE<'a, D>,
ode: ODE<'a, D, P>,
integrator: S,
controller: PIController,
callbacks: Vec<Callback<'a, D>>,
callbacks: Vec<Callback<'a, D, P>>,
}
impl<'a, const D: usize, S> Problem<'a,D,S>
impl<'a, const D: usize, S, P> Problem<'a,D,S,P>
where
S: Integrator<D> + Copy,
{
pub fn new(ode: ODE<'a,D>, integrator: S, controller: PIController) -> Self {
pub fn new(ode: ODE<'a,D,P>, integrator: S, controller: PIController) -> Self {
Problem {
ode: ode,
integrator: integrator,
@@ -55,17 +55,17 @@ where
if self.callbacks.len() > 0 {
// Check for events occurring
for callback in &self.callbacks {
println!("{}", (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y));
if (callback.event)(self.ode.t, self.ode.y) * (callback.event)(self.ode.t + step, new_y) < 0.0 {
println!("{}", (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params));
if (callback.event)(self.ode.t, self.ode.y, &self.ode.params) * (callback.event)(self.ode.t + step, new_y, &self.ode.params) < 0.0 {
// If the event crossed zero, then find the root
let f = |test_t| {
let test_y = self.integrator.step(&self.ode, test_t).0;
(callback.event)(self.ode.t + test_t, test_y)
(callback.event)(self.ode.t + test_t, test_y, &self.ode.params)
};
let root = find_root_regula_falsi(0.0, step, &f, &mut 1e-12).unwrap();
step = root;
(new_y, _, dense_option) = self.integrator.step(&self.ode, step);
self.ode = (callback.effect)(self.ode);
(callback.effect)(&mut self.ode);
}
}
}
@@ -84,7 +84,7 @@ where
}
}
pub fn with_callback(mut self, callback: Callback<'a, D>) -> Self {
pub fn with_callback(mut self, callback: Callback<'a, D, P>) -> Self {
self.callbacks.push(callback);
Self {
ode: self.ode,
@@ -140,10 +140,11 @@ mod tests {
#[test]
fn test_problem() {
fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { y }
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y }
let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 1.0, y0);
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64);
let controller = PIController::default();
@@ -157,15 +158,16 @@ mod tests {
#[test]
fn test_with_callback() {
fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { y }
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y }
let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 5.0, y0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-5_f64);
let controller = PIController::default();
let value_too_high = Callback {
event: &|_: f64, y: SVector<f64,3>| { 10.0 - y[0] },
event: &|_: f64, y: SVector<f64,3>, _: &Params| { 10.0 - y[0] },
effect: &stop,
};
@@ -178,10 +180,11 @@ mod tests {
#[test]
fn test_with_interpolation() {
fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { y }
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { y }
let y0 = Vector3::new(1.0, 1.0, 1.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
let dp45 = DormandPrince45::new(1e-12_f64, 1e-6_f64);
let controller = PIController::default();