From d63e8e68b900161f4080026e99b77ddc9d5a01c1 Mon Sep 17 00:00:00 2001
From: M Stoeckl
Date: Sun, 7 Dec 2025 11:31:44 -0500
Subject: [PATCH] Avoid panics in Binomial BTPE sampler by using u64 values

---
 CHANGELOG.md    |  1 +
 src/binomial.rs | 52 ++++++++++++++++++++++++++-----------------------
 2 files changed, 29 insertions(+), 24 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4d32e7a..579cce0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Use `direct-minimal-versions` (#38)
 - Fix panic in `FisherF::new` on almost zero parameters (#39)
 - Fix panic in `NormalInverseGaussian::new` with very large `alpha`; this is a Value-breaking change (#40)
+- Fix panic in `Binomial::sample` with `n ≥ 2^63`; this is a Value-breaking change (#43)
 - Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44)

 ## [0.5.2]
diff --git a/src/binomial.rs b/src/binomial.rs
index ff092ff..2e6ec82 100644
--- a/src/binomial.rs
+++ b/src/binomial.rs
@@ -78,7 +78,7 @@ struct Binv {
 struct Btpe {
     n: u64,
     p: f64,
-    m: i64,
+    m: u64,
     p1: f64,
 }

@@ -168,17 +168,17 @@ impl Binomial {
             let npq = np * q;
             let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
             let f_m = np + p;
-            let m = f64_to_i64(f_m);
+            let m = f64_to_u64(f_m);
             Method::Btpe(Btpe { n, p, m, p1 }, flipped)
         };
         Ok(Binomial { method })
     }
 }

-/// Convert a `f64` to an `i64`, panicking on overflow.
-fn f64_to_i64(x: f64) -> i64 {
-    assert!(x < (i64::MAX as f64));
-    x as i64
+/// Convert a `f64` to a `u64`, panicking on overflow.
+fn f64_to_u64(x: f64) -> u64 {
+    assert!(x >= 0.0 && x < (u64::MAX as f64));
+    x as u64
 }

 fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
@@ -211,11 +211,11 @@ fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
 fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
     // Threshold for using the squeeze algorithm. This can be freely
     // chosen based on performance. Ranlib and GSL use 20.
-    const SQUEEZE_THRESHOLD: i64 = 20;
+    const SQUEEZE_THRESHOLD: u64 = 20;

     // Step 0: Calculate constants as functions of `n` and `p`.
-    let n = btpe.n as f64;
-    let np = n * btpe.p;
+    let n = btpe.n;
+    let np = (n as f64) * btpe.p;
     let q = 1. - btpe.p;
     let npq = np * q;
     let f_m = np + btpe.p;
@@ -244,7 +244,7 @@ fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
     let p4 = p3 + c / lambda_r;

     // return value
-    let mut y: i64;
+    let mut y: u64;

     let gen_u = Uniform::new(0., p4).unwrap();
     let gen_v = Uniform::new(0., 1.).unwrap();
@@ -255,7 +255,7 @@ fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
         let u = gen_u.sample(rng);
         let mut v = gen_v.sample(rng);
         if !(u > p1) {
-            y = f64_to_i64(x_m - p1 * v + u);
+            y = f64_to_u64(x_m - p1 * v + u);
             break;
         }

@@ -267,20 +267,21 @@ fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
             if v > 1. {
                 continue;
             } else {
-                y = f64_to_i64(x);
+                y = f64_to_u64(x);
             }
         } else if !(u > p3) {
             // Step 3: Region 3, left exponential tail.
-            y = f64_to_i64(x_l + v.ln() / lambda_l);
-            if y < 0 {
+            let y_tmp = x_l + v.ln() / lambda_l;
+            if y_tmp < 0.0 {
                 continue;
             } else {
+                y = f64_to_u64(y_tmp);
                 v *= (u - p2) * lambda_l;
             }
         } else {
             // Step 4: Region 4, right exponential tail.
-            y = f64_to_i64(x_r - v.ln() / lambda_r);
-            if y > 0 && (y as u64) > btpe.n {
+            y = (x_r - v.ln() / lambda_r) as u64; // `as` cast saturates
+            if y > btpe.n {
                 continue;
             } else {
                 v *= (u - p3) * lambda_r;
@@ -290,12 +291,12 @@ fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
         // Step 5: Acceptance/rejection comparison.

         // Step 5.0: Test for appropriate method of evaluating f(y).
-        let k = (y - m).abs();
+        let k = y.abs_diff(m);
         if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
             // Step 5.1: Evaluate f(y) via the recursive relationship. Start the
             // search from the mode.
             let s = btpe.p / q;
-            let a = s * (n + 1.);
+            let a = s * (n as f64 + 1.);
             let mut f = 1.0;
             match m.cmp(&y) {
                 Ordering::Less => {
@@ -343,18 +344,23 @@ fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
             // Step 5.3: Final acceptance/rejection test.
             let x1 = (y + 1) as f64;
             let f1 = (m + 1) as f64;
-            let z = (f64_to_i64(n) + 1 - m) as f64;
-            let w = (f64_to_i64(n) - y + 1) as f64;
+            let z = ((n - m) + 1) as f64;
+            let w = ((n - y) + 1) as f64;

             fn stirling(a: f64) -> f64 {
                 let a2 = a * a;
                 (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
             }

+            let y_sub_m = if y > m {
+                (y - m) as f64
+            } else {
+                -((m - y) as f64)
+            };
             if alpha
                 > x_m * (f1 / x1).ln()
-                    + (n - (m as f64) + 0.5) * (z / w).ln()
-                    + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln()
+                    + (((n - m) as f64) + 0.5) * (z / w).ln()
+                    + y_sub_m * (w * btpe.p / (x1 * q)).ln()
                     // We use the signs from the GSL implementation, which are
                     // different than the ones in the reference. According to
                     // the GSL authors, the new signs were verified to be
@@ -370,8 +376,6 @@ fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
         break;
     }

-    assert!(y >= 0);
-    let y = y as u64;
     if flipped { btpe.n - y } else { y }
 }
