diff --git a/crates/RustQuant_math/src/interpolation/lagrange_interpolator.rs b/crates/RustQuant_math/src/interpolation/lagrange_interpolator.rs new file mode 100644 index 00000000..d0ff8f03 --- /dev/null +++ b/crates/RustQuant_math/src/interpolation/lagrange_interpolator.rs @@ -0,0 +1,190 @@ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// RustQuant: A Rust library for quantitative finance tools. +// Copyright (C) 2023 https://github.com/avhz +// Dual licensed under Apache 2.0 and MIT. +// See: +// - LICENSE-APACHE.md +// - LICENSE-MIT.md +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +//! Module containing functionality for interpolation. + +use crate::interpolation::{InterpolationIndex, InterpolationValue, Interpolator}; +use RustQuant_error::RustQuantError; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// STRUCTS & ENUMS +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Linear Interpolator. +pub struct LagrangeInterpolator +where + IndexType: InterpolationIndex, + ValueType: InterpolationValue, +{ + /// X-axis values for the interpolator. + pub xs: Vec, + + /// Y-axis values for the interpolator. + pub ys: Vec, + + /// Whether the interpolator has been fitted. + pub fitted: bool, +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// IMPLEMENTATIONS, FUNCTIONS, AND MACROS +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +impl LagrangeInterpolator +where + IndexType: InterpolationIndex, + ValueType: InterpolationValue, +{ + /// Create a new LagrangeInterpolator. + /// + /// # Errors + /// - `RustQuantError::UnequalLength` if ```xs.length() != ys.length()```. + /// + /// # Panics + /// Panics if NaN is in the index. + pub fn new( + xs: Vec, + ys: Vec, + ) -> Result, RustQuantError> { + if xs.len() != ys.len() { + return Err(RustQuantError::UnequalLength); + } + + let mut tmp: Vec<_> = xs.into_iter().zip(ys).collect(); + + tmp.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + let (xs, ys): (Vec, Vec) = tmp.into_iter().unzip(); + + Ok(Self { + xs, + ys, + fitted: false, + }) + } + + fn lagrange_basis(&self, point: IndexType, node: IndexType, index: usize) -> ValueType { + let mut basis: ValueType = ValueType::one(); + for (i, x) in self.xs.iter().enumerate() { + if i != index { + basis *= (point - *x) / (node - *x); + } + } + basis + } + + fn lagrange_polynomial(&self, point: IndexType) -> ValueType { + let mut polynomial: ValueType = ValueType::zero(); + for (i, (x, y)) in self.xs.iter().zip(&self.ys).enumerate() { + polynomial += *y * self.lagrange_basis(point, *x, i); + + } + polynomial + } +} + +impl Interpolator + for LagrangeInterpolator +where + IndexType: InterpolationIndex, + ValueType: InterpolationValue, +{ + fn fit(&mut self) -> Result<(), RustQuantError> { + self.fitted = true; + Ok(()) + } + + fn range(&self) -> (IndexType, IndexType) { + (*self.xs.first().unwrap(), *self.xs.last().unwrap()) + } + + fn add_point(&mut self, point: (IndexType, ValueType)) { + let idx = self.xs.partition_point(|&x| x < point.0); + self.xs.insert(idx, point.0); + self.ys.insert(idx, point.1); + } + + fn interpolate(&self, point: IndexType) -> Result { + let range = self.range(); + if point.partial_cmp(&range.0).unwrap() == std::cmp::Ordering::Less + || point.partial_cmp(&range.1).unwrap() == std::cmp::Ordering::Greater + { + return Err(RustQuantError::OutsideOfRange); + } + if let Ok(idx) = self + .xs + .binary_search_by(|p| p.partial_cmp(&point).expect("Cannot compare values.")) + { + return Ok(self.ys[idx]); + } + + Ok(self.lagrange_polynomial(point)) + } +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Unit tests +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +#[cfg(test)] +mod tests_lagrange_interpolation { + use super::*; + use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON}; + + #[test] + fn test_lagrange_interpolation() { + let xs: Vec = vec![0., 1., 2., 3., 4.]; + let ys: Vec = vec![1., 2., 4., 8., 16.]; + + let mut interpolator = LagrangeInterpolator::new(xs, ys).unwrap(); + let _ = interpolator.fit(); + + assert_approx_equal!( + 5.6484375, + interpolator.interpolate(2.5).unwrap(), + RUSTQUANT_EPSILON + ); + } + + #[test] + fn test_lagrange_interpolation_dates() { + let now: time::OffsetDateTime = time::OffsetDateTime::now_utc(); + + let xs: Vec = vec![ + now, + now + time::Duration::days(1), + now + time::Duration::days(2), + now + time::Duration::days(3), + now + time::Duration::days(4), + ]; + let ys: Vec = vec![1., 2., 4., 8., 16.]; + + let mut interpolator: LagrangeInterpolator = LagrangeInterpolator::new(xs.clone(), ys).unwrap(); + let _ = interpolator.fit(); + + assert_approx_equal!( + 5.6484375, + interpolator + .interpolate(xs[2] + time::Duration::hours(12)) + .unwrap(), + RUSTQUANT_EPSILON + ); + } + + #[test] + fn test_linear_interpolation_out_of_range() { + let xs: Vec = vec![1., 2., 3., 4., 5.]; + let ys: Vec = vec![1., 2., 3., 4., 5.]; + + let mut interpolator: LagrangeInterpolator = LagrangeInterpolator::new(xs, ys).unwrap(); + let _ = interpolator.fit(); + + assert!(interpolator.interpolate(6.).is_err()); + } +} diff --git a/crates/RustQuant_math/src/interpolation/mod.rs b/crates/RustQuant_math/src/interpolation/mod.rs index bec74e28..56653b01 100644 --- a/crates/RustQuant_math/src/interpolation/mod.rs +++ b/crates/RustQuant_math/src/interpolation/mod.rs @@ -19,6 +19,9 @@ pub use exponential_interpolator::*; pub mod b_splines; pub use b_splines::*; +pub mod lagrange_interpolator; +pub use lagrange_interpolator::*; + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~