Added parameters

This commit is contained in:
Connor Johnstone
2023-03-15 15:28:56 -06:00
parent 8daee11ae7
commit 63b603a57c
7 changed files with 123 additions and 48 deletions

View File

@@ -3,22 +3,30 @@ use nalgebra::SVector;
/// The basic ODE object that will be passed around. The type (T) and the size (D) will be
/// determined upon creation of the object
#[derive(Clone, Copy)]
pub struct ODE<'a, const D: usize> {
pub f: &'a dyn Fn(f64, SVector<f64,D>) -> SVector<f64,D>,
pub struct ODE<'a, const D: usize, P> {
pub f: &'a dyn Fn(f64, SVector<f64,D>, &P) -> SVector<f64,D>,
pub y: SVector<f64,D>,
pub t: f64,
pub params: P,
pub t0: f64,
pub t_end: f64,
pub h: f64,
pub finished: bool,
}
impl<'a, const D: usize> ODE<'a,D> {
pub fn new(f: &'a (dyn Fn(f64, SVector<f64,D>) -> SVector<f64,D>), t0: f64, t_end: f64, y0: SVector<f64,D>) -> Self {
impl<'a, const D: usize, P> ODE<'a,D, P> {
pub fn new(
f: &'a (dyn Fn(f64, SVector<f64,D>, &P) -> SVector<f64,D>),
t0: f64,
t_end: f64,
y0: SVector<f64,D>,
params: P,
) -> Self {
Self {
f: f,
y: y0,
t: t0,
params: params,
t0: t0,
t_end: t_end,
h: 0.001,
@@ -35,12 +43,32 @@ mod tests {
#[test]
fn test_ode_creation() {
fn derivative(_t: f64, y: Vector3<f64>) -> Vector3<f64> { -y }
type Params = ();
fn derivative(_t: f64, y: Vector3<f64>, _p: &Params) -> Vector3<f64> { -y }
let y0 = Vector3::new(1.0, 0.0, 0.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, ());
assert!((ode.f)(0.0, y0) == Vector3::new(-1.0, 0.0, 0.0));
assert!((ode.f)(0.0, y0, &()) == Vector3::new(-1.0, 0.0, 0.0));
assert!(ode.y == Vector3::new(1.0, 0.0, 0.0));
assert!(ode.t == 0.0);
assert!(!ode.finished);
assert!(ode.t_end == 10.0);
}
#[test]
fn test_ode_with_params() {
type Params = (f64, bool);
let params = (34.0, true);
fn derivative(t: f64, y: Vector3<f64>, p: &Params) -> Vector3<f64> {
if p.1 { -y } else { y * t }
}
let y0 = Vector3::new(1.0, 0.0, 0.0);
let ode = ODE::new(&derivative, 0.0, 10.0, y0, params);
assert!((ode.f)(0.0, y0, &params) == Vector3::new(-1.0, 0.0, 0.0));
assert!(ode.y == Vector3::new(1.0, 0.0, 0.0));
assert!(ode.t == 0.0);
assert!(!ode.finished);