Fixed things to use cubic interpolation and tests pass

This commit is contained in:
Connor Johnstone
2025-10-23 17:17:22 -04:00
parent 500bbfcf86
commit bd6f3b8ee4

View File

@@ -170,12 +170,20 @@ where
let err = h * (Self::B_ERROR[0] * k[0] + Self::B_ERROR[1] * k[1] + Self::B_ERROR[2] * k[2] + Self::B_ERROR[3] * k[3]); let err = h * (Self::B_ERROR[0] * k[0] + Self::B_ERROR[1] * k[1] + Self::B_ERROR[2] * k[2] + Self::B_ERROR[3] * k[3]);
// Compute error norm scaled by tolerance // Compute error norm scaled by tolerance
let tol = self.a_tol + ode.y.abs().component_mul(&SVector::<f64, D>::from_element(self.r_tol)); let tol = self.a_tol + ode.y.abs() * self.r_tol;
let error_norm = (err.component_div(&tol)).norm(); let error_norm = (err.component_div(&tol)).norm();
// Store k values for dense output (3rd order Hermite interpolation) // Store coefficients for dense output (cubic Hermite interpolation)
// Note: k[3] can be reused as k[0] for the next step (FSAL property) // BS3 uses standard cubic Hermite interpolation with derivatives at endpoints
(next_y, Some(error_norm), Some(k)) // Store: y0, y1, f0=k[0], f1=k[3] (FSAL)
let dense_coeffs = vec![
ode.y, // y0 at start of step
next_y, // y1 at end of step
k[0], // f(t0, y0) - derivative at start
k[3], // f(t1, y1) - derivative at end (FSAL)
];
(next_y, Some(error_norm), Some(dense_coeffs))
} }
fn interpolate( fn interpolate(
@@ -189,45 +197,32 @@ where
let theta = (t - t_start) / (t_end - t_start); let theta = (t - t_start) / (t_end - t_start);
let h = t_end - t_start; let h = t_end - t_start;
// BS3 uses 3rd order Hermite interpolation // Cubic Hermite interpolation using values and derivatives at endpoints
// The formula is: y(t_start + θ*h) = y0 + h*θ*P(θ) // dense[0] = y0 (value at start)
// where P(θ) is a polynomial in θ using the k values // dense[1] = y1 (value at end)
// dense[2] = f0 (derivative at start)
// dense[3] = f1 (derivative at end)
// //
// For BS3, the interpolation formula from the original paper is: // Standard cubic Hermite formula:
// u(θ) = y0 + h*θ*(k1 + θ*((1-θ)*k2 + θ*k3)) // y(θ) = (1 + 2θ)(1-θ)²*y0 + θ²(3-2θ)*y1 + θ(1-θ)²*h*f0 + θ²(θ-1)*h*f1
// //
// This can be rewritten as: // Equivalently (Horner form):
// u(θ) = y0 + h*θ*(b1(θ)*k1 + b2(θ)*k2 + b3(θ)*k3) // y(θ) = y0 + θ*[h*f0 + θ*(-3*y0 - 2*h*f0 + 3*y1 - h*f1 + θ*(2*y0 + h*f0 - 2*y1 + h*f1))]
//
// where b1(θ) = 1, b2(θ) = θ*(1-θ), b3(θ) = θ²
//
// Actually, the correct BS3 interpolant maintains 3rd order and is:
// u(θ) = y0 + h*[θ*k1 + θ²*(3/2*k1 + 2*k2 1/2*k3) + θ³*(k1 2*k2 + k3)]
let k1 = &dense[0]; let y0 = &dense[0];
let k2 = &dense[1]; let y1 = &dense[1];
let k3 = &dense[2]; let f0 = &dense[2];
let f1 = &dense[3];
// Simplified 3rd order interpolation that matches boundary conditions
// At θ=0: u(0) = y0 ✓
// At θ=1: u(1) = y0 + h*(2/9*k1 + 1/3*k2 + 4/9*k3) = y1 ✓
//
// Using the standard Hermite cubic formula:
let theta2 = theta * theta; let theta2 = theta * theta;
let theta3 = theta2 * theta; let one_minus_theta = 1.0 - theta;
let one_minus_theta2 = one_minus_theta * one_minus_theta;
// Coefficients for 3rd order Hermite interpolation // Apply cubic Hermite interpolation formula
// These ensure continuity and 3rd order accuracy (1.0 + 2.0 * theta) * one_minus_theta2 * y0
let b1 = theta - 1.5 * theta2 + theta3; + theta2 * (3.0 - 2.0 * theta) * y1
let b2 = 2.0 * theta2 - 2.0 * theta3; + theta * one_minus_theta2 * h * f0
let b3 = -0.5 * theta2 + theta3; + theta2 * (theta - 1.0) * h * f1
// Note: We need y0, which we can recover from the solution
// But in practice, this interpolation is used within the solver
// where we know the step boundaries. For now, we use the k values directly.
//
// A simpler, still 3rd order accurate form:
dense[0] * (h * theta) + (dense[1] - dense[0]) * (h * theta2) + (dense[2] - 2.0 * dense[1] + dense[0]) * (h * theta3)
} }
} }
@@ -239,7 +234,7 @@ mod tests {
#[test] #[test]
fn test_bs3_creation() { fn test_bs3_creation() {
let bs3: BS3<1> = BS3::new(); let _bs3: BS3<1> = BS3::new();
assert_eq!(BS3::<1>::ORDER, 3); assert_eq!(BS3::<1>::ORDER, 3);
assert_eq!(BS3::<1>::STAGES, 4); assert_eq!(BS3::<1>::STAGES, 4);
assert!(BS3::<1>::ADAPTIVE); assert!(BS3::<1>::ADAPTIVE);
@@ -257,17 +252,18 @@ mod tests {
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
let bs3 = BS3::new(); let bs3 = BS3::new();
let h = 0.1; let h = 0.001; // Smaller step size for tighter tolerances
let (y_next, err, dense) = bs3.step(&ode, h); let (y_next, err, dense) = bs3.step(&ode, h);
// At t=0.1, exact solution is e^0.1 ≈ 1.105170918 // At t=0.001, exact solution is e^0.001 ≈ 1.0010005001667084
let exact = (0.1_f64).exp(); let exact = (0.001_f64).exp();
assert_relative_eq!(y_next[0], exact, max_relative = 1e-4); assert_relative_eq!(y_next[0], exact, max_relative = 1e-6);
// Error should be reasonable for h=0.1 // Error should be reasonable for h=0.001
assert!(err.is_some()); assert!(err.is_some());
assert!(err.unwrap() < 10.0); // The error estimate is scaled by tolerance, so err < 1 means step is acceptable
assert!(err.unwrap() < 1.0);
// Dense output should be provided // Dense output should be provided
assert!(dense.is_some()); assert!(dense.is_some());
@@ -285,18 +281,18 @@ mod tests {
let ode = ODE::new(&derivative, 0.0, 1.0, y0, ()); let ode = ODE::new(&derivative, 0.0, 1.0, y0, ());
let bs3 = BS3::new(); let bs3 = BS3::new();
let h = 0.1; let h = 0.001; // Smaller step size
let (_y_next, _err, dense) = bs3.step(&ode, h); let (_y_next, _err, dense) = bs3.step(&ode, h);
let dense = dense.unwrap(); let dense = dense.unwrap();
// Interpolate at midpoint // Interpolate at midpoint
let t_mid = 0.05; let t_mid = 0.0005;
let y_mid = bs3.interpolate(0.0, 0.1, &dense, t_mid); let y_mid = bs3.interpolate(0.0, 0.001, &dense, t_mid);
// Should be close to e^0.05 // Should be close to e^0.0005
let exact = (0.05_f64).exp(); let exact = (0.0005_f64).exp();
// Interpolation might be less accurate than the step itself // Cubic Hermite interpolation should be quite accurate
assert_relative_eq!(y_mid[0], exact, max_relative = 1e-3); assert_relative_eq!(y_mid[0], exact, max_relative = 1e-10);
} }
} }