diff --git a/Cargo.lock b/Cargo.lock
index 1a5fdbd..a2f2c69 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,15 +2,146 @@
 # It is not intended for manual editing.
 version = 4
 
+[[package]]
+name = "aho-corasick"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "autocfg"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
+
+[[package]]
+name = "byteorder"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
+
 [[package]]
 name = "datasketches"
 version = "0.1.0"
 dependencies = [
+ "byteorder",
+ "googletest",
  "mur3",
 ]
+
+[[package]]
+name = "googletest"
+version = "0.14.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "06597b7d02ee58b9a37f522785ac15b9e18c6b178747c4439a6c03fbb35ea753"
+dependencies = [
+ "googletest_macro",
+ "num-traits",
+ "regex",
+ "rustversion",
+]
+
+[[package]]
+name = "googletest_macro"
+version = "0.14.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c31d9f07c9c19b855faebf71637be3b43f8e13a518aece5d61a3beee7710b4ef"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "memchr"
+version = "2.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
+
 [[package]]
 name = "mur3"
 version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "97af489e1e21b68de4c390ecca6703318bc1aa16e9733bcb62c089b73c6fbb1b"
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.103"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.42"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "regex"
+version = "1.12.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58"
+
+[[package]]
+name = "rustversion"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
+
+[[package]]
+name = "syn"
+version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" diff --git a/Cargo.toml b/Cargo.toml index ad4eeb6..90961f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,8 +35,12 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies] +byteorder = { version = "1.5.0" } mur3 = { version = "0.1.0" } +[dev-dependencies] +googletest = { version = "0.14.2" } + [lints.rust] unknown_lints = "deny" unsafe_code = "deny" diff --git a/src/lib.rs b/src/lib.rs index 07ace55..20f5940 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,3 +28,4 @@ pub mod error; pub mod hll; +pub mod tdigest; diff --git a/src/tdigest/mod.rs b/src/tdigest/mod.rs new file mode 100644 index 0000000..ad9ca42 --- /dev/null +++ b/src/tdigest/mod.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! T-Digest implementation for estimating quantiles and ranks. +//! +//! The implementation in this library is based on the MergingDigest described in +//! [Computing Extremely Accurate Quantiles Using t-Digests][paper] by Ted Dunning and Otmar Ertl. +//! +//! The implementation in this library has a few differences from the reference implementation +//! associated with that paper: +//! +//! * Merge does not modify the input +//! * Deserialization similar to other sketches in this library, although reading the reference +//! implementation format is supported +//! +//! Unlike all other algorithms in the library, t-digest is empirical and has no mathematical +//! basis for estimating its error and its results are dependent on the input data. However, +//! for many common data distributions, it can produce excellent results. t-digest also operates +//! only on numeric data and, unlike the quantiles family algorithms in the library which return +//! quantile approximations from the input domain, t-digest interpolates values and will hold and +//! return data points not seen in the input. +//! +//! The closest alternative to t-digest in this library is REQ sketch. It prioritizes one chosen +//! side of the rank domain: either low rank accuracy or high rank accuracy. t-digest (in this +//! implementation) prioritizes both ends of the rank domain and has lower accuracy towards the +//! middle of the rank domain (median). +//! +//! Measurements show that t-digest is slightly biased (tends to underestimate low ranks and +//! 
diff --git a/src/tdigest/serialization.rs b/src/tdigest/serialization.rs
new file mode 100644
index 0000000..e5b9788
--- /dev/null
+++ b/src/tdigest/serialization.rs
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub(super) const PREAMBLE_LONGS_EMPTY_OR_SINGLE: u8 = 1;
+pub(super) const PREAMBLE_LONGS_MULTIPLE: u8 = 2;
+pub(super) const SERIAL_VERSION: u8 = 1;
+pub(super) const TDIGEST_FAMILY_ID: u8 = 20;
+pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 0;
+pub(super) const FLAGS_IS_SINGLE_VALUE: u8 = 1 << 1;
+pub(super) const FLAGS_REVERSE_MERGE: u8 = 1 << 2;
+/// The format of the reference implementation uses double (f64) precision.
+pub(super) const COMPAT_DOUBLE: u32 = 1;
+/// The format of the reference implementation uses float (f32) precision.
+pub(super) const COMPAT_FLOAT: u32 = 2;
diff --git a/src/tdigest/sketch.rs b/src/tdigest/sketch.rs
new file mode 100644
index 0000000..7f125d9
--- /dev/null
+++ b/src/tdigest/sketch.rs
@@ -0,0 +1,1170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::error::SerdeError;
+use crate::tdigest::serialization::*;
+use byteorder::{BE, LE, ReadBytesExt};
+use std::cmp::Ordering;
+use std::convert::identity;
+use std::io::Cursor;
+use std::num::NonZeroU64;
+
+/// The default value of K if one is not specified.
+const DEFAULT_K: u16 = 200;
+/// Multiplier for buffer size relative to centroids capacity.
+const BUFFER_MULTIPLIER: usize = 4;
+/// Default weight for single values.
+const DEFAULT_WEIGHT: NonZeroU64 = NonZeroU64::new(1).unwrap();
+
+/// T-Digest sketch for estimating quantiles and ranks.
+///
+/// See the [tdigest module level documentation](crate::tdigest) for more.
+#[derive(Debug, Clone)]
+pub struct TDigestMut {
+    k: u16,
+
+    reverse_merge: bool,
+    min: f64,
+    max: f64,
+
+    centroids: Vec<Centroid>,
+    centroids_weight: u64,
+    centroids_capacity: usize,
+    buffer: Vec<f64>,
+}
+
+impl Default for TDigestMut {
+    fn default() -> Self {
+        TDigestMut::new(DEFAULT_K)
+    }
+}
+
+impl TDigestMut {
+    /// Creates a tdigest instance with the given value of k.
+    ///
+    /// # Panics
+    ///
+    /// If k is less than 10.
+    pub fn new(k: u16) -> Self {
+        Self::make(
+            k,
+            false,
+            f64::INFINITY,
+            f64::NEG_INFINITY,
+            vec![],
+            0,
+            vec![],
+        )
+    }
+
+    // for deserialization
+    fn make(
+        k: u16,
+        reverse_merge: bool,
+        min: f64,
+        max: f64,
+        mut centroids: Vec<Centroid>,
+        centroids_weight: u64,
+        mut buffer: Vec<f64>,
+    ) -> Self {
+        assert!(k >= 10, "k must be at least 10");
+
+        let fudge = if k < 30 { 30 } else { 10 };
+        let centroids_capacity = (k as usize * 2) + fudge;
+
+        centroids.reserve(centroids_capacity);
+        buffer.reserve(centroids_capacity * BUFFER_MULTIPLIER);
+
+        TDigestMut {
+            k,
+            reverse_merge,
+            min,
+            max,
+            centroids,
+            centroids_weight,
+            centroids_capacity,
+            buffer,
+        }
+    }
+
+    /// Update this TDigest with the given value.
+    ///
+    /// [f64::NAN], [f64::INFINITY], and [f64::NEG_INFINITY] values are ignored.
+    pub fn update(&mut self, value: f64) {
+        if value.is_nan() || value.is_infinite() {
+            return;
+        }
+
+        if self.buffer.len() == self.centroids_capacity * BUFFER_MULTIPLIER {
+            self.compress();
+        }
+
+        self.buffer.push(value);
+        self.min = self.min.min(value);
+        self.max = self.max.max(value);
+    }
+
+    /// Returns parameter k (compression) that was used to configure this TDigest.
+    pub fn k(&self) -> u16 {
+        self.k
+    }
+
+    /// Returns true if TDigest has not seen any data.
+    pub fn is_empty(&self) -> bool {
+        self.centroids.is_empty() && self.buffer.is_empty()
+    }
+
+    /// Returns minimum value seen by TDigest; `None` if TDigest is empty.
+    pub fn min_value(&self) -> Option<f64> {
+        if self.is_empty() {
+            None
+        } else {
+            Some(self.min)
+        }
+    }
+
+    /// Returns maximum value seen by TDigest; `None` if TDigest is empty.
+    pub fn max_value(&self) -> Option<f64> {
+        if self.is_empty() {
+            None
+        } else {
+            Some(self.max)
+        }
+    }
+
+    /// Returns total weight.
+    pub fn total_weight(&self) -> u64 {
+        self.centroids_weight + self.buffer.len() as u64
+    }
+
+    /// Merge the given TDigest into this one.
+    pub fn merge(&mut self, other: &TDigestMut) {
+        if other.is_empty() {
+            return;
+        }
+
+        let mut tmp = Vec::with_capacity(
+            self.centroids.len() + self.buffer.len() + other.centroids.len() + other.buffer.len(),
+        );
+        for &v in &self.buffer {
+            tmp.push(Centroid {
+                mean: v,
+                weight: DEFAULT_WEIGHT,
+            });
+        }
+        for &v in &other.buffer {
+            tmp.push(Centroid {
+                mean: v,
+                weight: DEFAULT_WEIGHT,
+            });
+        }
+        for &c in &other.centroids {
+            tmp.push(c);
+        }
+        self.do_merge(tmp, self.buffer.len() as u64 + other.total_weight())
+    }
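`merge` takes the other digest by shared reference, matching the module-level note that merging does not modify the input. A small sketch of that contract (not part of the patch):

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let mut left = TDigestMut::new(100);
    let mut right = TDigestMut::new(100);
    for i in 0..500 {
        left.update(i as f64);
        right.update((500 + i) as f64);
    }

    // merge() borrows `right`, so it remains fully usable afterwards.
    left.merge(&right);
    assert_eq!(left.total_weight(), 1000);
    assert_eq!(right.total_weight(), 500);
    assert_eq!(left.max_value(), Some(999.0));
}
```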
+    /// Freezes this TDigest into an immutable one.
+    pub fn freeze(mut self) -> TDigest {
+        self.compress();
+        TDigest {
+            k: self.k,
+            reverse_merge: self.reverse_merge,
+            min: self.min,
+            max: self.max,
+            centroids: self.centroids,
+            centroids_weight: self.centroids_weight,
+        }
+    }
+
+    fn view(&mut self) -> TDigestView<'_> {
+        self.compress(); // side effect
+        TDigestView {
+            min: self.min,
+            max: self.max,
+            centroids: &self.centroids,
+            centroids_weight: self.centroids_weight,
+        }
+    }
+
+    /// Returns an approximation to the Cumulative Distribution Function (CDF), which is the
+    /// cumulative analog of the PMF, of the input stream given a set of split points.
+    ///
+    /// # Arguments
+    ///
+    /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+    ///   input domain into _m+1_ consecutive disjoint intervals.
+    ///
+    /// # Returns
+    ///
+    /// An array of _m+1_ doubles, which are a consecutive approximation to the CDF of the input
+    /// stream given the split points. The value at array position j of the returned CDF array
+    /// is the sum of the returned values in positions 0 through j of the returned PMF array.
+    /// This can be viewed as an array of ranks of the given split points plus one more value
+    /// that is always 1.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+    pub fn cdf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
+        check_split_points(split_points);
+
+        if self.is_empty() {
+            return None;
+        }
+
+        self.view().cdf(split_points)
+    }
+
+    /// Returns an approximation to the Probability Mass Function (PMF) of the input stream
+    /// given a set of split points.
+    ///
+    /// # Arguments
+    ///
+    /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+    ///   input domain into _m+1_ consecutive disjoint intervals (bins).
+    ///
+    /// # Returns
+    ///
+    /// An array of _m+1_ doubles, each of which is an approximation to the fraction of the input
+    /// stream values (the mass) that fall into one of those intervals.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+    pub fn pmf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
+        check_split_points(split_points);
+
+        if self.is_empty() {
+            return None;
+        }
+
+        self.view().pmf(split_points)
+    }
+
+    /// Compute approximate normalized rank (from 0 to 1 inclusive) of the given value.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If the value is `NaN`.
+    pub fn rank(&mut self, value: f64) -> Option<f64> {
+        assert!(!value.is_nan(), "value must not be NaN");
+
+        if self.is_empty() {
+            return None;
+        }
+        if value < self.min {
+            return Some(0.0);
+        }
+        if value > self.max {
+            return Some(1.0);
+        }
+        // one centroid and value == min == max
+        if self.centroids.len() + self.buffer.len() == 1 {
+            return Some(0.5);
+        }
+
+        self.view().rank(value)
+    }
+
+    /// Compute approximate quantile value corresponding to the given normalized rank.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If rank is not in [0.0, 1.0].
+    pub fn quantile(&mut self, rank: f64) -> Option<f64> {
+        assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
+
+        if self.is_empty() {
+            return None;
+        }
+
+        self.view().quantile(rank)
+    }
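The `cdf`/`pmf` pair is linked exactly as the doc comments state: the PMF is the first difference of the CDF. A standalone check (illustrative, not part of the patch):

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let mut td = TDigestMut::new(100);
    for i in 0..1000 {
        td.update(i as f64);
    }

    let split_points = [250.0, 500.0, 750.0];
    let cdf = td.cdf(&split_points).unwrap(); // ranks of the split points, plus a final 1.0
    let pmf = td.pmf(&split_points).unwrap(); // mass in each of the 4 intervals

    // Both arrays have split_points.len() + 1 entries, and pmf[j] == cdf[j] - cdf[j-1].
    assert_eq!(cdf.len(), 4);
    for j in 1..cdf.len() {
        assert!((pmf[j] - (cdf[j] - cdf[j - 1])).abs() < 1e-12);
    }
}
```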
+    /// Serializes this TDigest to bytes.
+    pub fn serialize(&mut self) -> Vec<u8> {
+        self.compress();
+
+        let mut total_size = 0;
+        if self.is_empty() || self.is_single_value() {
+            // 1 byte preamble
+            // + 1 byte serial version
+            // + 1 byte family
+            // + 2 bytes k
+            // + 1 byte flags
+            // + 2 bytes unused
+            total_size += size_of::<u64>();
+        } else {
+            // all of the above
+            // + 4 bytes num centroids
+            // + 4 bytes num buffered
+            total_size += size_of::<u64>() * 2;
+        }
+        if self.is_empty() {
+            // nothing more
+        } else if self.is_single_value() {
+            // + 8 bytes single value
+            total_size += size_of::<f64>();
+        } else {
+            // + 8 bytes min
+            // + 8 bytes max
+            total_size += size_of::<f64>() * 2;
+            // + (8+8) bytes per centroid
+            total_size += self.centroids.len() * (size_of::<f64>() + size_of::<u64>());
+        }
+
+        let mut bytes = Vec::with_capacity(total_size);
+        bytes.push(match self.total_weight() {
+            0 | 1 => PREAMBLE_LONGS_EMPTY_OR_SINGLE,
+            _ => PREAMBLE_LONGS_MULTIPLE,
+        });
+        bytes.push(SERIAL_VERSION);
+        bytes.push(TDIGEST_FAMILY_ID);
+        bytes.extend_from_slice(&self.k.to_le_bytes());
+        bytes.push({
+            let mut flags = 0;
+            if self.is_empty() {
+                flags |= FLAGS_IS_EMPTY;
+            }
+            if self.is_single_value() {
+                flags |= FLAGS_IS_SINGLE_VALUE;
+            }
+            if self.reverse_merge {
+                flags |= FLAGS_REVERSE_MERGE;
+            }
+            flags
+        });
+        bytes.extend_from_slice(&0u16.to_le_bytes()); // unused
+        if self.is_empty() {
+            return bytes;
+        }
+        if self.is_single_value() {
+            bytes.extend_from_slice(&self.min.to_le_bytes());
+            return bytes;
+        }
+        bytes.extend_from_slice(&(self.centroids.len() as u32).to_le_bytes());
+        bytes.extend_from_slice(&0u32.to_le_bytes()); // unused
+        bytes.extend_from_slice(&self.min.to_le_bytes());
+        bytes.extend_from_slice(&self.max.to_le_bytes());
+        for centroid in &self.centroids {
+            bytes.extend_from_slice(&centroid.mean.to_le_bytes());
+            bytes.extend_from_slice(&centroid.weight.get().to_le_bytes());
+        }
+        bytes
+    }
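The header written by `serialize` can be inspected directly. This sketch (not part of the patch) checks the preamble fields of an empty digest against the constants from `serialization.rs`:

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let mut td = TDigestMut::new(100);
    let bytes = td.serialize();

    // Empty digest: the 8-byte preamble only.
    assert_eq!(bytes.len(), 8);
    assert_eq!(bytes[0], 1); // PREAMBLE_LONGS_EMPTY_OR_SINGLE
    assert_eq!(bytes[1], 1); // SERIAL_VERSION
    assert_eq!(bytes[2], 20); // TDIGEST_FAMILY_ID
    assert_eq!(u16::from_le_bytes([bytes[3], bytes[4]]), 100); // k
    assert_eq!(bytes[5] & 1, 1); // FLAGS_IS_EMPTY is set
}
```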
+    /// Deserializes a TDigest from bytes.
+    ///
+    /// Supports reading a compact format with (float, int) centroids, as opposed to
+    /// (double, long), to represent (mean, weight). [^1]
+    ///
+    /// Supports reading the format of the reference implementation (auto-detected). [^2]
+    ///
+    /// [^1]: This is to support reading the `tdigest` format from the C++ implementation.
+    /// [^2]: <https://github.com/tdunning/t-digest>
+    pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, SerdeError> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+            move |_| SerdeError::InsufficientData(tag.to_string())
+        }
+
+        let mut cursor = Cursor::new(bytes);
+
+        let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?;
+        let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?;
+        let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
+        if family_id != TDIGEST_FAMILY_ID {
+            if preamble_longs == 0 && serial_version == 0 && family_id == 0 {
+                return Self::deserialize_compat(bytes);
+            }
+            return Err(SerdeError::InvalidFamily(format!(
+                "expected {} (TDigest), got {}",
+                TDIGEST_FAMILY_ID, family_id
+            )));
+        }
+        if serial_version != SERIAL_VERSION {
+            return Err(SerdeError::UnsupportedVersion(format!(
+                "expected {}, got {}",
+                SERIAL_VERSION, serial_version
+            )));
+        }
+        let k = cursor.read_u16::<LE>().map_err(make_error("k"))?;
+        if k < 10 {
+            return Err(SerdeError::InvalidParameter(format!(
+                "k must be at least 10, got {k}"
+            )));
+        }
+        let flags = cursor.read_u8().map_err(make_error("flags"))?;
+        let is_empty = (flags & FLAGS_IS_EMPTY) != 0;
+        let is_single_value = (flags & FLAGS_IS_SINGLE_VALUE) != 0;
+        let expected_preamble_longs = if is_empty || is_single_value {
+            PREAMBLE_LONGS_EMPTY_OR_SINGLE
+        } else {
+            PREAMBLE_LONGS_MULTIPLE
+        };
+        if preamble_longs != expected_preamble_longs {
+            return Err(SerdeError::MalformedData(format!(
+                "expected preamble_longs to be {}, got {}",
+                expected_preamble_longs, preamble_longs
+            )));
+        }
+        cursor.read_u16::<LE>().map_err(make_error("unused"))?; // unused
+        if is_empty {
+            return Ok(TDigestMut::new(k));
+        }
+
+        let reverse_merge = (flags & FLAGS_REVERSE_MERGE) != 0;
+        if is_single_value {
+            let value = if is_f32 {
+                cursor
+                    .read_f32::<LE>()
+                    .map_err(make_error("single_value"))? as f64
+            } else {
+                cursor
+                    .read_f64::<LE>()
+                    .map_err(make_error("single_value"))?
+            };
+            check_non_nan(value, "single_value")?;
+            check_non_infinite(value, "single_value")?;
+            return Ok(TDigestMut::make(
+                k,
+                reverse_merge,
+                value,
+                value,
+                vec![Centroid {
+                    mean: value,
+                    weight: DEFAULT_WEIGHT,
+                }],
+                1,
+                vec![],
+            ));
+        }
+        let num_centroids = cursor
+            .read_u32::<LE>()
+            .map_err(make_error("num_centroids"))? as usize;
+        let num_buffered = cursor
+            .read_u32::<LE>()
+            .map_err(make_error("num_buffered"))? as usize;
+        let (min, max) = if is_f32 {
+            (
+                cursor.read_f32::<LE>().map_err(make_error("min"))? as f64,
+                cursor.read_f32::<LE>().map_err(make_error("max"))? as f64,
+            )
+        } else {
+            (
+                cursor.read_f64::<LE>().map_err(make_error("min"))?,
+                cursor.read_f64::<LE>().map_err(make_error("max"))?,
+            )
+        };
+        check_non_nan(min, "min")?;
+        check_non_nan(max, "max")?;
+        let mut centroids = Vec::with_capacity(num_centroids);
+        let mut centroids_weight = 0u64;
+        for _ in 0..num_centroids {
+            let (mean, weight) = if is_f32 {
+                (
+                    cursor.read_f32::<LE>().map_err(make_error("mean"))? as f64,
+                    cursor.read_u32::<LE>().map_err(make_error("weight"))? as u64,
+                )
+            } else {
+                (
+                    cursor.read_f64::<LE>().map_err(make_error("mean"))?,
+                    cursor.read_u64::<LE>().map_err(make_error("weight"))?,
+                )
+            };
+            check_non_nan(mean, "centroid mean")?;
+            check_non_infinite(mean, "centroid mean")?;
+            let weight = check_nonzero(weight, "centroid weight")?;
+            centroids_weight += weight.get();
+            centroids.push(Centroid { mean, weight });
+        }
+        let mut buffer = Vec::with_capacity(num_buffered);
+        for _ in 0..num_buffered {
+            let value = if is_f32 {
+                cursor
+                    .read_f32::<LE>()
+                    .map_err(make_error("buffered_value"))? as f64
+            } else {
+                cursor
+                    .read_f64::<LE>()
+                    .map_err(make_error("buffered_value"))?
+            };
+            check_non_nan(value, "buffered_value")?;
+            check_non_infinite(value, "buffered_value")?;
+            buffer.push(value);
+        }
+        Ok(TDigestMut::make(
+            k,
+            reverse_merge,
+            min,
+            max,
+            centroids,
+            centroids_weight,
+            buffer,
+        ))
+    }
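A roundtrip through `serialize`/`deserialize`, plus the truncated-input failure mode (illustrative, not part of the patch):

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let mut td = TDigestMut::new(100);
    for i in 0..1000 {
        td.update(i as f64);
    }
    let bytes = td.serialize();

    // The second argument selects the compact (f32) centroid encoding;
    // this library always writes f64, so pass false for its own output.
    let restored = TDigestMut::deserialize(&bytes, false).unwrap();
    assert_eq!(restored.total_weight(), 1000);

    // Truncated input surfaces as a SerdeError rather than a panic.
    assert!(TDigestMut::deserialize(&bytes[..4], false).is_err());
}
```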
+
+    // compatibility with the format of the reference implementation;
+    // the default byte order of ByteBuffer is used there, which is big-endian
+    fn deserialize_compat(bytes: &[u8]) -> Result<Self, SerdeError> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+            move |_| SerdeError::InsufficientData(format!("{tag} in compat format"))
+        }
+
+        let mut cursor = Cursor::new(bytes);
+
+        let ty = cursor.read_u32::<BE>().map_err(make_error("type"))?;
+        match ty {
+            COMPAT_DOUBLE => {
+                fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+                    move |_| SerdeError::InsufficientData(format!("{tag} in compat double format"))
+                }
+                // COMPAT_DOUBLE: compatibility with asBytes()
+                let min = cursor.read_f64::<BE>().map_err(make_error("min"))?;
+                let max = cursor.read_f64::<BE>().map_err(make_error("max"))?;
+                check_non_nan(min, "min in compat double format")?;
+                check_non_nan(max, "max in compat double format")?;
+                let k = cursor.read_f64::<BE>().map_err(make_error("k"))? as u16;
+                if k < 10 {
+                    return Err(SerdeError::InvalidParameter(format!(
+                        "k must be at least 10, got {k} in compat double format"
+                    )));
+                }
+                let num_centroids = cursor
+                    .read_u32::<BE>()
+                    .map_err(make_error("num_centroids"))? as usize;
+                let mut total_weight = 0u64;
+                let mut centroids = Vec::with_capacity(num_centroids);
+                for _ in 0..num_centroids {
+                    let weight = cursor.read_f64::<BE>().map_err(make_error("weight"))? as u64;
+                    let mean = cursor.read_f64::<BE>().map_err(make_error("mean"))?;
+                    let weight = check_nonzero(weight, "centroid weight in compat double format")?;
+                    check_non_nan(mean, "centroid mean in compat double format")?;
+                    check_non_infinite(mean, "centroid mean in compat double format")?;
+                    total_weight += weight.get();
+                    centroids.push(Centroid { mean, weight });
+                }
+                Ok(TDigestMut::make(
+                    k,
+                    false,
+                    min,
+                    max,
+                    centroids,
+                    total_weight,
+                    vec![],
+                ))
+            }
+            COMPAT_FLOAT => {
+                fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
+                    move |_| SerdeError::InsufficientData(format!("{tag} in compat float format"))
+                }
+                // COMPAT_FLOAT: compatibility with asSmallBytes()
+                // the reference implementation uses doubles for min and max
+                let min = cursor.read_f64::<BE>().map_err(make_error("min"))?;
+                let max = cursor.read_f64::<BE>().map_err(make_error("max"))?;
+                check_non_nan(min, "min in compat float format")?;
+                check_non_nan(max, "max in compat float format")?;
+                let k = cursor.read_f32::<BE>().map_err(make_error("k"))? as u16;
+                if k < 10 {
+                    return Err(SerdeError::InvalidParameter(format!(
+                        "k must be at least 10, got {k} in compat float format"
+                    )));
+                }
+                // the reference implementation stores the capacities of the array of centroids
+                // and the buffer as shorts; they can be derived from k in the constructor
+                cursor.read_u32::<BE>().map_err(make_error("capacities"))?;
+                let num_centroids = cursor
+                    .read_u16::<BE>()
+                    .map_err(make_error("num_centroids"))? as usize;
+                let mut total_weight = 0u64;
+                let mut centroids = Vec::with_capacity(num_centroids);
+                for _ in 0..num_centroids {
+                    let weight = cursor.read_f32::<BE>().map_err(make_error("weight"))? as u64;
+                    let mean = cursor.read_f32::<BE>().map_err(make_error("mean"))? as f64;
+                    let weight = check_nonzero(weight, "centroid weight in compat float format")?;
+                    check_non_nan(mean, "centroid mean in compat float format")?;
+                    check_non_infinite(mean, "centroid mean in compat float format")?;
+                    total_weight += weight.get();
+                    centroids.push(Centroid { mean, weight });
+                }
+                Ok(TDigestMut::make(
+                    k,
+                    false,
+                    min,
+                    max,
+                    centroids,
+                    total_weight,
+                    vec![],
+                ))
+            }
+            ty => Err(SerdeError::InvalidParameter(format!(
+                "unknown TDigest compat type {ty}",
+            ))),
+        }
+    }
+
+    fn is_single_value(&self) -> bool {
+        self.total_weight() == 1
+    }
+
+    /// Process buffered values and merge centroids if needed.
+    fn compress(&mut self) {
+        if self.buffer.is_empty() {
+            return;
+        }
+        let mut tmp = Vec::with_capacity(self.buffer.len() + self.centroids.len());
+        for &v in &self.buffer {
+            tmp.push(Centroid {
+                mean: v,
+                weight: DEFAULT_WEIGHT,
+            });
+        }
+        self.do_merge(tmp, self.buffer.len() as u64)
+    }
+
+    /// Merges the given buffer of centroids into this TDigest.
+    ///
+    /// # Contract
+    ///
+    /// * `buffer` must have at least one centroid.
+    /// * `buffer` is generated from `self.buffer`, and thus:
+    ///   * No `NAN` values are present in `buffer`.
+    ///   * We should clear `self.buffer` after merging.
+    fn do_merge(&mut self, mut buffer: Vec<Centroid>, weight: u64) {
+        buffer.extend(std::mem::take(&mut self.centroids));
+        buffer.sort_by(centroid_cmp);
+        if self.reverse_merge {
+            buffer.reverse();
+        }
+        self.centroids_weight += weight;
+
+        let mut num_centroids = 0;
+        let len = buffer.len();
+        self.centroids.push(buffer[0]);
+        num_centroids += 1;
+        let mut current = 1;
+        let mut weight_so_far = 0.;
+        while current < len {
+            let c = buffer[current];
+            let proposed_weight = self.centroids[num_centroids - 1].weight() + c.weight();
+            let mut add_this = false;
+            if (current != 1) && (current != (len - 1)) {
+                let centroids_weight = self.centroids_weight as f64;
+                let q0 = weight_so_far / centroids_weight;
+                let q2 = (weight_so_far + proposed_weight) / centroids_weight;
+                // computed in f64 to avoid overflowing u16 for large k
+                let normalizer = scale_function::normalizer(2. * self.k as f64, centroids_weight);
+                add_this = proposed_weight
+                    <= (centroids_weight
+                        * scale_function::max(q0, normalizer)
+                            .min(scale_function::max(q2, normalizer)));
+            }
+            if add_this {
+                // merge into the existing centroid
+                self.centroids[num_centroids - 1].add(c);
+            } else {
+                // copy to a new centroid
+                weight_so_far += self.centroids[num_centroids - 1].weight();
+                self.centroids.push(c);
+                num_centroids += 1;
+            }
+            current += 1;
+        }
+
+        if self.reverse_merge {
+            self.centroids.reverse();
+        }
+        self.min = self.min.min(self.centroids[0].mean);
+        self.max = self.max.max(self.centroids[num_centroids - 1].mean);
+        self.reverse_merge = !self.reverse_merge;
+        self.buffer.clear();
+    }
+}
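`do_merge` bounds each merged centroid using the K_2 scale function defined near the bottom of this file: a centroid at relative rank q may hold at most n·q(1-q)/normalizer weight. A standalone restatement of that bound (not part of the patch; values illustrative) shows why tail clusters stay small:

```rust
// Restatement of the scale_function module for illustration only.
fn z(compression: f64, n: f64) -> f64 {
    4. * (n / compression).ln() + 24.
}

fn normalizer(compression: f64, n: f64) -> f64 {
    compression / z(compression, n)
}

fn main() {
    let (k, n) = (200.0, 1_000_000.0);
    let norm = normalizer(2.0 * k, n);
    // Maximum allowed centroid weight at the median (q = 0.5) vs. near a tail (q = 0.01).
    let at_median = n * 0.5 * (1.0 - 0.5) / norm;
    let near_tail = n * 0.01 * (1.0 - 0.01) / norm;
    println!("max weight: median ~ {at_median:.0}, q=0.01 ~ {near_tail:.0}");
}
```

Because the bound is proportional to q(1-q), centroids near rank 0 and rank 1 are forced to stay tiny, which is what gives t-digest its accuracy at the extremes.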
+
+/// Immutable (frozen) T-Digest sketch for estimating quantiles and ranks.
+///
+/// See the [module documentation](super) for more details.
+pub struct TDigest {
+    k: u16,
+
+    reverse_merge: bool,
+    min: f64,
+    max: f64,
+
+    centroids: Vec<Centroid>,
+    centroids_weight: u64,
+}
+
+impl TDigest {
+    /// Returns parameter k (compression) that was used to configure this TDigest.
+    pub fn k(&self) -> u16 {
+        self.k
+    }
+
+    /// Returns true if TDigest has not seen any data.
+    pub fn is_empty(&self) -> bool {
+        self.centroids.is_empty()
+    }
+
+    /// Returns minimum value seen by TDigest; `None` if TDigest is empty.
+    pub fn min_value(&self) -> Option<f64> {
+        if self.is_empty() {
+            None
+        } else {
+            Some(self.min)
+        }
+    }
+
+    /// Returns maximum value seen by TDigest; `None` if TDigest is empty.
+    pub fn max_value(&self) -> Option<f64> {
+        if self.is_empty() {
+            None
+        } else {
+            Some(self.max)
+        }
+    }
+
+    /// Returns total weight.
+    pub fn total_weight(&self) -> u64 {
+        self.centroids_weight
+    }
+
+    fn view(&self) -> TDigestView<'_> {
+        TDigestView {
+            min: self.min,
+            max: self.max,
+            centroids: &self.centroids,
+            centroids_weight: self.centroids_weight,
+        }
+    }
+
+    /// Returns an approximation to the Cumulative Distribution Function (CDF), which is the
+    /// cumulative analog of the PMF, of the input stream given a set of split points.
+    ///
+    /// # Arguments
+    ///
+    /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+    ///   input domain into _m+1_ consecutive disjoint intervals.
+    ///
+    /// # Returns
+    ///
+    /// An array of _m+1_ doubles, which are a consecutive approximation to the CDF of the input
+    /// stream given the split points. The value at array position j of the returned CDF array
+    /// is the sum of the returned values in positions 0 through j of the returned PMF array.
+    /// This can be viewed as an array of ranks of the given split points plus one more value
+    /// that is always 1.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+    pub fn cdf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+        self.view().cdf(split_points)
+    }
+
+    /// Returns an approximation to the Probability Mass Function (PMF) of the input stream
+    /// given a set of split points.
+    ///
+    /// # Arguments
+    ///
+    /// * `split_points`: An array of _m_ unique, monotonically increasing values that divide the
+    ///   input domain into _m+1_ consecutive disjoint intervals (bins).
+    ///
+    /// # Returns
+    ///
+    /// An array of _m+1_ doubles, each of which is an approximation to the fraction of the input
+    /// stream values (the mass) that fall into one of those intervals.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+    pub fn pmf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+        self.view().pmf(split_points)
+    }
+
+    /// Compute approximate normalized rank (from 0 to 1 inclusive) of the given value.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If the value is `NaN`.
+    pub fn rank(&self, value: f64) -> Option<f64> {
+        assert!(!value.is_nan(), "value must not be NaN");
+        self.view().rank(value)
+    }
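`TDigestMut` and `TDigest` form a freeze/unfreeze pair; a lifecycle sketch (not part of the patch):

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let mut td = TDigestMut::new(100);
    for i in 0..1000 {
        td.update(i as f64);
    }

    // freeze() compresses and drops the update buffer; queries on TDigest
    // take &self, so the frozen sketch can be shared freely.
    let frozen = td.freeze();
    let median = frozen.quantile(0.5).unwrap();

    // unfreeze() converts back to a TDigestMut for further updates.
    let mut td = frozen.unfreeze();
    td.update(median);
    assert_eq!(td.total_weight(), 1001);
}
```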
+
+    /// Compute approximate quantile value corresponding to the given normalized rank.
+    ///
+    /// Returns `None` if TDigest is empty.
+    ///
+    /// # Panics
+    ///
+    /// If rank is not in [0.0, 1.0].
+    pub fn quantile(&self, rank: f64) -> Option<f64> {
+        assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
+        self.view().quantile(rank)
+    }
+
+    /// Converts this immutable TDigest into a mutable one.
+    pub fn unfreeze(self) -> TDigestMut {
+        TDigestMut::make(
+            self.k,
+            self.reverse_merge,
+            self.min,
+            self.max,
+            self.centroids,
+            self.centroids_weight,
+            vec![],
+        )
+    }
+}
+
+struct TDigestView<'a> {
+    min: f64,
+    max: f64,
+    centroids: &'a [Centroid],
+    centroids_weight: u64,
+}
+
+impl TDigestView<'_> {
+    fn pmf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+        let mut buckets = self.cdf(split_points)?;
+        for i in (1..buckets.len()).rev() {
+            buckets[i] -= buckets[i - 1];
+        }
+        Some(buckets)
+    }
+
+    fn cdf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
+        check_split_points(split_points);
+
+        if self.centroids.is_empty() {
+            return None;
+        }
+
+        let mut ranks = Vec::with_capacity(split_points.len() + 1);
+        for &p in split_points {
+            match self.rank(p) {
+                Some(rank) => ranks.push(rank),
+                None => unreachable!("checked non-empty above"),
+            }
+        }
+        ranks.push(1.0);
+        Some(ranks)
+    }
+
+    fn rank(&self, value: f64) -> Option<f64> {
+        debug_assert!(!value.is_nan(), "value must not be NaN");
+
+        if self.centroids.is_empty() {
+            return None;
+        }
+        if value < self.min {
+            return Some(0.0);
+        }
+        if value > self.max {
+            return Some(1.0);
+        }
+        // one centroid and value == min == max
+        if self.centroids.len() == 1 {
+            return Some(0.5);
+        }
+
+        let centroids_weight = self.centroids_weight as f64;
+        let num_centroids = self.centroids.len();
+
+        // left tail
+        let first_mean = self.centroids[0].mean;
+        if value < first_mean {
+            if first_mean - self.min > 0. {
+                return Some(if value == self.min {
+                    0.5 / centroids_weight
+                } else {
+                    (1. + (((value - self.min) / (first_mean - self.min))
+                        * ((self.centroids[0].weight() / 2.) - 1.)))
+                        / centroids_weight
+                });
+            }
+            return Some(0.); // should never happen
+        }
+
+        // right tail
+        let last_mean = self.centroids[num_centroids - 1].mean;
+        if value > last_mean {
+            if self.max - last_mean > 0. {
+                return Some(if value == self.max {
+                    1. - (0.5 / centroids_weight)
+                } else {
+                    1.0 - ((1.0
+                        + (((self.max - value) / (self.max - last_mean))
+                            * ((self.centroids[num_centroids - 1].weight() / 2.) - 1.)))
+                        / centroids_weight)
+                });
+            }
+            return Some(1.); // should never happen
+        }
+
+        let mut lower = self
+            .centroids
+            .binary_search_by(|c| centroid_lower_bound(c, value))
+            .unwrap_or_else(identity);
+        assert_ne!(lower, num_centroids, "get_rank: lower == end");
+        let mut upper = self
+            .centroids
+            .binary_search_by(|c| centroid_upper_bound(c, value))
+            .unwrap_or_else(identity);
+        assert_ne!(upper, 0, "get_rank: upper == begin");
+        if value < self.centroids[lower].mean {
+            lower -= 1;
+        }
+        if (upper == num_centroids) || (self.centroids[upper - 1].mean >= value) {
+            upper -= 1;
+        }
+
+        let mut weight_below = 0.;
+        let mut i = 0;
+        while i < lower {
+            weight_below += self.centroids[i].weight();
+            i += 1;
+        }
+        weight_below += self.centroids[lower].weight() / 2.;
+
+        let mut weight_delta = 0.;
+        while i < upper {
+            weight_delta += self.centroids[i].weight();
+            i += 1;
+        }
+        weight_delta -= self.centroids[lower].weight() / 2.;
+        weight_delta += self.centroids[upper].weight() / 2.;
+        Some(
+            if self.centroids[upper].mean - self.centroids[lower].mean > 0. {
+                (weight_below
+                    + (weight_delta * (value - self.centroids[lower].mean)
+                        / (self.centroids[upper].mean - self.centroids[lower].mean)))
+                    / centroids_weight
+            } else {
+                (weight_below + weight_delta / 2.) / centroids_weight
+            },
+        )
+    }
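The tail and interpolation branches of `rank` can be checked by hand with two unit-weight centroids; this mirrors the `test_rank_two_values` case added later in this patch:

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let mut td = TDigestMut::new(100);
    td.update(1.0);
    td.update(2.0);

    // Two unit-weight centroids at 1.0 and 2.0 (total weight 2):
    // rank(min) = 0.5 / 2 = 0.25, rank(max) = 1 - 0.5 / 2 = 0.75,
    // and values in between interpolate linearly.
    assert_eq!(td.rank(1.0), Some(0.25));
    assert_eq!(td.rank(1.5), Some(0.5));
    assert_eq!(td.rank(1.75), Some(0.625));
    assert_eq!(td.rank(2.0), Some(0.75));
}
```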
+
+    fn quantile(&self, rank: f64) -> Option<f64> {
+        debug_assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
+
+        if self.centroids.is_empty() {
+            return None;
+        }
+
+        if self.centroids.len() == 1 {
+            return Some(self.centroids[0].mean);
+        }
+
+        // at least 2 centroids
+        let centroids_weight = self.centroids_weight as f64;
+        let num_centroids = self.centroids.len();
+        let weight = rank * centroids_weight;
+        if weight < 1. {
+            return Some(self.min);
+        }
+        if weight > centroids_weight - 1. {
+            return Some(self.max);
+        }
+        let first_weight = self.centroids[0].weight();
+        if first_weight > 1. && weight < first_weight / 2. {
+            return Some(
+                self.min
+                    + (((weight - 1.) / ((first_weight / 2.) - 1.))
+                        * (self.centroids[0].mean - self.min)),
+            );
+        }
+        let last_weight = self.centroids[num_centroids - 1].weight();
+        if last_weight > 1. && (centroids_weight - weight <= last_weight / 2.) {
+            return Some(
+                self.max
+                    + (((centroids_weight - weight - 1.) / ((last_weight / 2.) - 1.))
+                        * (self.max - self.centroids[num_centroids - 1].mean)),
+            );
+        }
+
+        // interpolate between extremes
+        let mut weight_so_far = first_weight / 2.;
+        for i in 0..(num_centroids - 1) {
+            let dw = (self.centroids[i].weight() + self.centroids[i + 1].weight()) / 2.;
+            if weight_so_far + dw > weight {
+                // the target weight is between centroids i and i+1
+                let mut left_weight = 0.;
+                if self.centroids[i].weight.get() == 1 {
+                    if weight - weight_so_far < 0.5 {
+                        return Some(self.centroids[i].mean);
+                    }
+                    left_weight = 0.5;
+                }
+                let mut right_weight = 0.;
+                if self.centroids[i + 1].weight.get() == 1 {
+                    if weight_so_far + dw - weight < 0.5 {
+                        return Some(self.centroids[i + 1].mean);
+                    }
+                    right_weight = 0.5;
+                }
+                let w1 = weight - weight_so_far - left_weight;
+                let w2 = weight_so_far + dw - weight - right_weight;
+                return Some(weighted_average(
+                    self.centroids[i].mean,
+                    w1,
+                    self.centroids[i + 1].mean,
+                    w2,
+                ));
+            }
+            weight_so_far += dw;
+        }
+
+        let w1 = weight - centroids_weight - (self.centroids[num_centroids - 1].weight() / 2.);
+        let w2 = (self.centroids[num_centroids - 1].weight() / 2.) - w1;
+        Some(weighted_average(
+            self.centroids[num_centroids - 1].mean,
+            w1,
+            self.max,
+            w2,
+        ))
+    }
+}
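`quantile` ultimately interpolates with `weighted_average`; a tiny worked example (standalone restatement of the private helper, not part of the patch):

```rust
// Restatement of the private weighted_average helper for illustration only.
fn weighted_average(x1: f64, w1: f64, x2: f64, w2: f64) -> f64 {
    (x1 * w1 + x2 * w2) / (w1 + w2)
}

fn main() {
    // A target weight three quarters of the way toward the left centroid
    // yields a value one quarter of the way between the two means.
    let q = weighted_average(10.0, 3.0, 20.0, 1.0);
    assert_eq!(q, 12.5);
}
```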
+
+/// Checks the sequential validity of the given array of double values.
+/// They must be unique, monotonically increasing, and not NaN.
+#[track_caller]
+fn check_split_points(split_points: &[f64]) {
+    let len = split_points.len();
+    if len == 1 && split_points[0].is_nan() {
+        panic!("split_points must not contain NaN values: {split_points:?}");
+    }
+    // start at 1 so that an empty slice does not underflow the loop bound
+    for i in 1..len {
+        if split_points[i - 1] < split_points[i] {
+            // we must use this positive condition because NaN comparisons are always false
+            continue;
+        }
+        panic!("split_points must be unique and monotonically increasing: {split_points:?}");
+    }
+}
+
+fn centroid_cmp(a: &Centroid, b: &Centroid) -> Ordering {
+    match a.mean.partial_cmp(&b.mean) {
+        Some(order) => order,
+        None => unreachable!("NaN values should never be present in centroids"),
+    }
+}
+
+fn centroid_lower_bound(c: &Centroid, value: f64) -> Ordering {
+    if c.mean < value {
+        Ordering::Less
+    } else {
+        Ordering::Greater
+    }
+}
+
+fn centroid_upper_bound(c: &Centroid, value: f64) -> Ordering {
+    if c.mean > value {
+        Ordering::Greater
+    } else {
+        Ordering::Less
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+struct Centroid {
+    mean: f64,
+    weight: NonZeroU64,
+}
+
+impl Centroid {
+    fn add(&mut self, other: Centroid) {
+        let (self_weight, other_weight) = (self.weight(), other.weight());
+        let total_weight = self_weight + other_weight;
+        self.weight = self.weight.saturating_add(other.weight.get());
+
+        let (self_mean, other_mean) = (self.mean, other.mean);
+        let ratio_self = self_weight / total_weight;
+        let ratio_other = other_weight / total_weight;
+        self.mean = self_mean.mul_add(ratio_self, other_mean * ratio_other);
+        debug_assert!(
+            !self.mean.is_nan(),
+            "NaN values should never be present in centroids; self: {}, other: {}",
+            self_mean,
+            other_mean
+        );
+    }
+
+    fn weight(&self) -> f64 {
+        self.weight.get() as f64
+    }
+}
+
+fn check_non_nan(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+    if value.is_nan() {
+        return Err(SerdeError::MalformedData(format!("{tag} cannot be NaN")));
+    }
+    Ok(())
+}
+
+fn check_non_infinite(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+    if value.is_infinite() {
+        return Err(SerdeError::MalformedData(format!(
+            "{tag} cannot be infinite"
+        )));
+    }
+    Ok(())
+}
+
+fn check_nonzero(value: u64, tag: &'static str) -> Result<NonZeroU64, SerdeError> {
+    NonZeroU64::new(value).ok_or_else(|| SerdeError::MalformedData(format!("{tag} cannot be zero")))
+}
+
+/// Generates cluster sizes proportional to `q*(1-q)`.
+///
+/// The use of a normalizing function results in a strictly bounded number of clusters no matter
+/// how many samples are processed.
+///
+/// Corresponds to K_2 in the reference implementation.
+mod scale_function {
+    pub(super) fn max(q: f64, normalizer: f64) -> f64 {
+        q * (1. - q) / normalizer
+    }
+
+    pub(super) fn normalizer(compression: f64, n: f64) -> f64 {
+        compression / z(compression, n)
+    }
+
+    pub(super) fn z(compression: f64, n: f64) -> f64 {
+        4. * (n / compression).ln() + 24.
+    }
+}
+
+const fn weighted_average(x1: f64, w1: f64, x2: f64, w2: f64) -> f64 {
+    (x1 * w1 + x2 * w2) / (w1 + w2)
+}
diff --git a/tests/common.rs b/tests/common.rs
new file mode 100644
index 0000000..e97b920
--- /dev/null
+++ b/tests/common.rs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::path::PathBuf;
+
+#[allow(dead_code)] // false-positive
+pub fn test_data(name: &str) -> PathBuf {
+    const TEST_DATA_DIR: &str = "tests/test_data";
+
+    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+        .join(TEST_DATA_DIR)
+        .join(name)
+}
+
+pub fn serialization_test_data(sub_dir: &str, name: &str) -> PathBuf {
+    const SERDE_TEST_DATA_DIR: &str = "tests/serialization_test_data";
+
+    let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+        .join(SERDE_TEST_DATA_DIR)
+        .join(sub_dir)
+        .join(name);
+
+    if !path.exists() {
+        panic!(
+            r#"serialization test data file not found: {}
+
+            Please ensure test data files are present in the repository. Generally, you can
+            run the following command from the project root to regenerate the test data files
+            if they are missing:
+
+            $ ./tools/generate_serialization_test_data.py
+            "#,
+            path.display(),
+        );
+    }
+
+    path
+}
diff --git a/tests/hll_serialization_test.rs b/tests/hll_serialization_test.rs
index a3c397b..fc1969c 100644
--- a/tests/hll_serialization_test.rs
+++ b/tests/hll_serialization_test.rs
@@ -24,36 +24,14 @@
 //! Test data is generated by the reference implementations and stored in:
 //! `tests/serialization_test_data/`
 
+mod common;
+
 use std::fs;
 use std::path::PathBuf;
 
+use common::serialization_test_data;
 use datasketches::hll::HllSketch;
 
-const TEST_DATA_DIR: &str = "tests/serialization_test_data";
-
-fn get_test_data_path(sub_dir: &str, name: &str) -> PathBuf {
-    let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
-        .join(TEST_DATA_DIR)
-        .join(sub_dir)
-        .join(name);
-
-    if !path.exists() {
-        panic!(
-            r#"serialization test data file not found: {}
-
-            Please ensure test data files are present in the repository. Generally, you can
-            run the following commands from the project root to regenerate the test data files
-            if they are missing:
-
-            $ ./tools/generate_serialization_test_data.py
-            "#,
-            path.display(),
-        );
-    }
-
-    path
-}
-
 fn test_sketch_file(path: PathBuf, expected_cardinality: usize, expected_lg_k: u8) {
     let expected = expected_cardinality as f64;
 
@@ -133,7 +111,7 @@ fn test_java_hll4_compatibility() {
 
     for n in test_cases {
         let filename = format!("hll4_n{}_java.sk", n);
-        let path = get_test_data_path("java_generated_files", &filename);
+        let path = serialization_test_data("java_generated_files", &filename);
         test_sketch_file(path, n, 12);
     }
 }
@@ -144,7 +122,7 @@ fn test_java_hll6_compatibility() {
 
     for n in test_cases {
         let filename = format!("hll6_n{}_java.sk", n);
-        let path = get_test_data_path("java_generated_files", &filename);
+        let path = serialization_test_data("java_generated_files", &filename);
         test_sketch_file(path, n, 12);
     }
 }
@@ -155,7 +133,7 @@ fn test_java_hll8_compatibility() {
 
     for n in test_cases {
         let filename = format!("hll8_n{}_java.sk", n);
-        let path = get_test_data_path("java_generated_files", &filename);
+        let path = serialization_test_data("java_generated_files", &filename);
         test_sketch_file(path, n, 12);
     }
 }
@@ -166,7 +144,7 @@ fn test_cpp_hll4_compatibility() {
 
     for n in test_cases {
        let filename = format!("hll4_n{}_cpp.sk", n);
-        let path = get_test_data_path("cpp_generated_files", &filename);
+        let path = serialization_test_data("cpp_generated_files", &filename);
         test_sketch_file(path, n, 12);
     }
 }
@@ -177,7 +155,7 @@ fn test_cpp_hll6_compatibility() {
 
     for n in test_cases {
         let filename = format!("hll6_n{}_cpp.sk", n);
-        let path = get_test_data_path("cpp_generated_files", &filename);
+        let path = serialization_test_data("cpp_generated_files", &filename);
         test_sketch_file(path, n, 12);
     }
 }
@@ -188,7 +166,7 @@ fn test_cpp_hll8_compatibility() {
 
     for n in test_cases {
         let filename = format!("hll8_n{}_cpp.sk", n);
-        let path = get_test_data_path("cpp_generated_files", &filename);
+        let path = serialization_test_data("cpp_generated_files", &filename);
         test_sketch_file(path, n, 12);
     }
 }
@@ -208,7 +186,7 @@ fn test_estimate_accuracy() {
     println!("{:-<40}", "");
 
     for (dir, file, expected) in test_cases {
-        let path = get_test_data_path(dir, file);
+        let path = serialization_test_data(dir, file);
         let bytes = fs::read(&path).unwrap();
         let sketch = HllSketch::deserialize(&bytes).unwrap();
         let estimate = sketch.estimate();
diff --git a/tests/tdigest_serialization_test.rs b/tests/tdigest_serialization_test.rs
new file mode 100644
index 0000000..0ad68e4
--- /dev/null
+++ b/tests/tdigest_serialization_test.rs
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod common;
+
+use std::fs;
+use std::path::PathBuf;
+
+use common::serialization_test_data;
+use common::test_data;
+use datasketches::tdigest::TDigestMut;
+use googletest::assert_that;
+use googletest::prelude::{eq, near};
+
+fn test_sketch_file(path: PathBuf, n: u64, with_buffer: bool, is_f32: bool) {
+    let bytes = fs::read(&path).unwrap();
+    let td = TDigestMut::deserialize(&bytes, is_f32).unwrap();
+    let td = td.freeze();
+
+    let path = path.display();
+    if n == 0 {
+        assert!(td.is_empty(), "filepath: {path}");
+        assert_eq!(td.total_weight(), 0, "filepath: {path}");
+    } else {
+        assert!(!td.is_empty(), "filepath: {path}");
+        assert_eq!(td.total_weight(), n, "filepath: {path}");
+        assert_eq!(td.min_value(), Some(1.0), "filepath: {path}");
+        assert_eq!(td.max_value(), Some(n as f64), "filepath: {path}");
+        assert_eq!(td.rank(0.0), Some(0.0), "filepath: {path}");
+        assert_eq!(td.rank((n + 1) as f64), Some(1.0), "filepath: {path}");
+        if n == 1 {
+            assert_eq!(td.rank(n as f64), Some(0.5), "filepath: {path}");
+        } else {
+            assert_that!(
+                td.rank(n as f64 / 2.).unwrap(),
+                near(0.5, 0.05),
+                "filepath: {path}",
+            );
+        }
+    }
+
+    if !with_buffer && !is_f32 {
+        let mut td = td.unfreeze();
+        let roundtrip_bytes = td.serialize();
+        assert_eq!(bytes, roundtrip_bytes, "filepath: {path}");
+    }
+}
+
+#[test]
+fn test_deserialize_from_cpp_snapshots() {
+    let ns = [0, 1, 10, 100, 1000, 10_000, 100_000, 1_000_000];
+    for n in ns {
+        let filename = format!("tdigest_double_n{}_cpp.sk", n);
+        let path = serialization_test_data("cpp_generated_files", &filename);
+        test_sketch_file(path, n, false, false);
+    }
+    for n in ns {
+        let filename = format!("tdigest_double_buf_n{}_cpp.sk", n);
+        let path = serialization_test_data("cpp_generated_files", &filename);
+        test_sketch_file(path, n, true, false);
+    }
+    for n in ns {
+        let filename = format!("tdigest_float_n{}_cpp.sk", n);
+        let path = serialization_test_data("cpp_generated_files", &filename);
+        test_sketch_file(path, n, false, true);
+    }
+    for n in ns {
+        let filename = format!("tdigest_float_buf_n{}_cpp.sk", n);
+        let path = serialization_test_data("cpp_generated_files", &filename);
+        test_sketch_file(path, n, true, true);
+    }
+}
+
+#[test]
+fn test_deserialize_from_reference_implementation() {
+    for filename in [
+        "tdigest_ref_k100_n10000_double.sk",
+        "tdigest_ref_k100_n10000_float.sk",
+    ] {
+        let path = test_data(filename);
+        let bytes = fs::read(&path).unwrap();
+        let td = TDigestMut::deserialize(&bytes, false).unwrap();
+        let td = td.freeze();
+
+        let n = 10000;
+        let path = path.display();
+        assert_eq!(td.k(), 100, "filepath: {path}");
+        assert_eq!(td.total_weight(), n, "filepath: {path}");
+        assert_eq!(td.min_value(), Some(0.0), "filepath: {path}");
+        assert_eq!(td.max_value(), Some((n - 1) as f64), "filepath: {path}");
+        assert_that!(td.rank(0.0).unwrap(), near(0.0, 0.0001), "filepath: {path}");
+        assert_that!(
+            td.rank(n as f64 / 4.).unwrap(),
+            near(0.25, 0.0001),
+            "filepath: {path}"
+        );
+        assert_that!(
+            td.rank(n as f64 / 2.).unwrap(),
+            near(0.5, 0.0001),
+            "filepath: {path}"
+        );
+        assert_that!(
+            td.rank((n * 3) as f64 / 4.).unwrap(),
+            near(0.75, 0.0001),
+            "filepath: {path}"
+        );
+        assert_that!(td.rank(n as f64).unwrap(), eq(1.0), "filepath: {path}");
+    }
+}
+
+#[test]
+fn test_deserialize_from_java_snapshots() {
+    let ns = [0, 1, 10, 100, 1000, 10_000, 100_000, 1_000_000];
+    for n in ns {
+        let filename = format!("tdigest_double_n{}_java.sk", n);
+        let path = serialization_test_data("java_generated_files", &filename);
+        test_sketch_file(path, n, false, false);
+    }
+}
+
+#[test]
+fn test_empty() {
+    let mut td = TDigestMut::new(100);
+    assert!(td.is_empty());
+
+    let bytes = td.serialize();
+    assert_eq!(bytes.len(), 8);
+    let td = td.freeze();
+
+    let deserialized_td = TDigestMut::deserialize(&bytes, false).unwrap();
+    let deserialized_td = deserialized_td.freeze();
+    assert_eq!(td.k(), deserialized_td.k());
+    assert_eq!(td.total_weight(), deserialized_td.total_weight());
+    assert!(td.is_empty());
+    assert!(deserialized_td.is_empty());
+}
+
+#[test]
+fn test_single_value() {
+    let mut td = TDigestMut::default();
+    td.update(123.0);
+
+    let bytes = td.serialize();
+    assert_eq!(bytes.len(), 16);
+
+    let deserialized_td = TDigestMut::deserialize(&bytes, false).unwrap();
+    let deserialized_td = deserialized_td.freeze();
+    assert_eq!(deserialized_td.k(), 200);
+    assert_eq!(deserialized_td.total_weight(), 1);
+    assert!(!deserialized_td.is_empty());
+    assert_eq!(deserialized_td.min_value(), Some(123.0));
+    assert_eq!(deserialized_td.max_value(), Some(123.0));
+}
+
+#[test]
+fn test_many_values() {
+    let mut td = TDigestMut::new(100);
+    for i in 0..1000 {
+        td.update(i as f64);
+    }
+
+    let bytes = td.serialize();
+    assert_eq!(bytes.len(), 1584);
+    let td = td.freeze();
+
+    let deserialized_td = TDigestMut::deserialize(&bytes, false).unwrap();
+    let deserialized_td = deserialized_td.freeze();
+    assert_eq!(td.k(), deserialized_td.k());
+    assert_eq!(td.total_weight(), deserialized_td.total_weight());
+    assert_eq!(td.is_empty(), deserialized_td.is_empty());
+    assert_eq!(td.min_value(), deserialized_td.min_value());
+    assert_eq!(td.max_value(), deserialized_td.max_value());
+    assert_eq!(td.rank(500.0), deserialized_td.rank(500.0));
+    assert_eq!(td.quantile(0.5), deserialized_td.quantile(0.5));
+}
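Assuming the C++ snapshot files referenced above have been generated, reading one directly looks like this (illustrative; the file name follows the pattern used in `test_deserialize_from_cpp_snapshots`):

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    // Path assumes the test data has been generated into the repository.
    let bytes =
        std::fs::read("tests/serialization_test_data/cpp_generated_files/tdigest_float_n1000_cpp.sk")
            .expect("test data present");
    // Pass is_f32 = true for the compact (f32 mean, u32 weight) centroid encoding.
    let td = TDigestMut::deserialize(&bytes, true).unwrap().freeze();
    println!("n = {}", td.total_weight());
}
```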
diff --git a/tests/tdigest_test.rs b/tests/tdigest_test.rs
new file mode 100644
index 0000000..1ae1ae3
--- /dev/null
+++ b/tests/tdigest_test.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datasketches::tdigest::TDigestMut;
+use googletest::assert_that;
+use googletest::prelude::{eq, near};
+
+#[test]
+fn test_empty() {
+    let mut tdigest = TDigestMut::new(10);
+    assert!(tdigest.is_empty());
+    assert_eq!(tdigest.k(), 10);
+    assert_eq!(tdigest.total_weight(), 0);
+    assert_eq!(tdigest.min_value(), None);
+    assert_eq!(tdigest.max_value(), None);
+    assert_eq!(tdigest.rank(0.0), None);
+    assert_eq!(tdigest.quantile(0.5), None);
+
+    let split_points = [0.0];
+    assert_eq!(tdigest.pmf(&split_points), None);
+    assert_eq!(tdigest.cdf(&split_points), None);
+
+    let tdigest = TDigestMut::new(10).freeze();
+    assert!(tdigest.is_empty());
+    assert_eq!(tdigest.k(), 10);
+    assert_eq!(tdigest.total_weight(), 0);
+    assert_eq!(tdigest.min_value(), None);
+    assert_eq!(tdigest.max_value(), None);
+    assert_eq!(tdigest.rank(0.0), None);
+    assert_eq!(tdigest.quantile(0.5), None);
+
+    let split_points = [0.0];
+    assert_eq!(tdigest.pmf(&split_points), None);
+    assert_eq!(tdigest.cdf(&split_points), None);
+}
+
+#[test]
+fn test_one_value() {
+    let mut tdigest = TDigestMut::new(100);
+    tdigest.update(1.0);
+    assert_eq!(tdigest.k(), 100);
+    assert_eq!(tdigest.total_weight(), 1);
+    assert_eq!(tdigest.min_value(), Some(1.0));
+    assert_eq!(tdigest.max_value(), Some(1.0));
+    assert_eq!(tdigest.rank(0.99), Some(0.0));
+    assert_eq!(tdigest.rank(1.0), Some(0.5));
+    assert_eq!(tdigest.rank(1.01), Some(1.0));
+    assert_eq!(tdigest.quantile(0.0), Some(1.0));
+    assert_eq!(tdigest.quantile(0.5), Some(1.0));
+    assert_eq!(tdigest.quantile(1.0), Some(1.0));
+}
+
+#[test]
+fn test_many_values() {
+    let n = 10000;
+
+    let mut tdigest = TDigestMut::default();
+    for i in 0..n {
+        tdigest.update(i as f64);
+    }
+
+    assert!(!tdigest.is_empty());
+    assert_eq!(tdigest.total_weight(), n);
+    assert_eq!(tdigest.min_value(), Some(0.0));
+    assert_eq!(tdigest.max_value(), Some((n - 1) as f64));
+
+    assert_that!(tdigest.rank(0.0).unwrap(), near(0.0, 0.0001));
+    assert_that!(tdigest.rank((n / 4) as f64).unwrap(), near(0.25, 0.0001));
+    assert_that!(tdigest.rank((n / 2) as f64).unwrap(), near(0.5, 0.0001));
+    assert_that!(
+        tdigest.rank((n * 3 / 4) as f64).unwrap(),
+        near(0.75, 0.0001)
+    );
+    assert_that!(tdigest.rank(n as f64).unwrap(), eq(1.0));
+    assert_that!(tdigest.quantile(0.0).unwrap(), eq(0.0));
+    assert_that!(
+        tdigest.quantile(0.5).unwrap(),
+        near((n / 2) as f64, 0.03 * (n / 2) as f64)
+    );
+    assert_that!(
+        tdigest.quantile(0.9).unwrap(),
+        near((n as f64) * 0.9, 0.01 * (n as f64) * 0.9)
+    );
+    assert_that!(
+        tdigest.quantile(0.95).unwrap(),
+        near((n as f64) * 0.95, 0.01 * (n as f64) * 0.95)
+    );
+    assert_that!(tdigest.quantile(1.0).unwrap(), eq((n - 1) as f64));
+
+    let split_points = [n as f64 / 2.0];
+    let pmf = tdigest.pmf(&split_points).unwrap();
+    assert_eq!(pmf.len(), 2);
+    assert_that!(pmf[0], near(0.5, 0.0001));
+    assert_that!(pmf[1], near(0.5, 0.0001));
+    let cdf = tdigest.cdf(&split_points).unwrap();
+    assert_eq!(cdf.len(), 2);
+    assert_that!(cdf[0], near(0.5, 0.0001));
+    assert_that!(cdf[1], eq(1.0));
+}
+
+#[test]
+fn test_rank_two_values() {
+    let mut tdigest = TDigestMut::new(100);
+    tdigest.update(1.0);
+    tdigest.update(2.0);
+    assert_eq!(tdigest.rank(0.99), Some(0.0));
+    assert_eq!(tdigest.rank(1.0), Some(0.25));
+    assert_eq!(tdigest.rank(1.25), Some(0.375));
+    assert_eq!(tdigest.rank(1.5), Some(0.5));
+    assert_eq!(tdigest.rank(1.75), Some(0.625));
+    assert_eq!(tdigest.rank(2.0), Some(0.75));
+    assert_eq!(tdigest.rank(2.01), Some(1.0));
+}
+
+#[test]
+fn test_rank_repeated_values() {
+    let mut tdigest = TDigestMut::new(100);
+    tdigest.update(1.0);
+    tdigest.update(1.0);
+    tdigest.update(1.0);
+    tdigest.update(1.0);
+    assert_eq!(tdigest.rank(0.99), Some(0.0));
+    assert_eq!(tdigest.rank(1.0), Some(0.5));
+    assert_eq!(tdigest.rank(1.01), Some(1.0));
+}
+
+#[test]
+fn test_repeated_blocks() {
+    let mut tdigest = TDigestMut::new(100);
+    tdigest.update(1.0);
+    tdigest.update(2.0);
+    tdigest.update(2.0);
+    tdigest.update(3.0);
+    assert_eq!(tdigest.rank(0.99), Some(0.0));
+    assert_eq!(tdigest.rank(1.0), Some(0.125));
+    assert_eq!(tdigest.rank(2.0), Some(0.5));
+    assert_eq!(tdigest.rank(3.0), Some(0.875));
+    assert_eq!(tdigest.rank(3.01), Some(1.0));
+}
+
+#[test]
+fn test_merge_small() {
+    let mut td1 = TDigestMut::new(10);
+    td1.update(1.0);
+    td1.update(2.0);
+    let mut td2 = TDigestMut::new(10);
+    td2.update(2.0);
+    td2.update(3.0);
+    td1.merge(&td2);
+    assert_eq!(td1.min_value(), Some(1.0));
+    assert_eq!(td1.max_value(), Some(3.0));
+    assert_eq!(td1.total_weight(), 4);
+    assert_eq!(td1.rank(0.99), Some(0.0));
+    assert_eq!(td1.rank(1.0), Some(0.125));
+    assert_eq!(td1.rank(2.0), Some(0.5));
+    assert_eq!(td1.rank(3.0), Some(0.875));
+    assert_eq!(td1.rank(3.01), Some(1.0));
+}
+
+#[test]
+fn test_merge_large() {
+    let n = 10000;
+
+    let mut td1 = TDigestMut::new(10);
+    let mut td2 = TDigestMut::new(10);
+    let sup = n / 2;
+    for i in 0..sup {
+        td1.update(i as f64);
+        td2.update((sup + i) as f64);
+    }
+    td1.merge(&td2);
+
+    assert_eq!(td1.total_weight(), n);
+    assert_eq!(td1.min_value(), Some(0.0));
+    assert_eq!(td1.max_value(), Some((n - 1) as f64));
+
+    assert_that!(td1.rank(0.0).unwrap(), near(0.0, 0.0001));
+    assert_that!(td1.rank((n / 4) as f64).unwrap(), near(0.25, 0.0001));
+    assert_that!(td1.rank((n / 2) as f64).unwrap(), near(0.5, 0.0001));
+    assert_that!(td1.rank((n * 3 / 4) as f64).unwrap(), near(0.75, 0.0001));
+    assert_that!(td1.rank(n as f64).unwrap(), eq(1.0));
+}
+
+#[test]
+fn test_invalid_inputs() {
+    let n = 100;
+
+    let mut td = TDigestMut::new(10);
+    for _ in 0..n {
+        td.update(f64::NAN);
+    }
+    assert!(td.is_empty());
+
+    let mut td = TDigestMut::new(10);
+    for _ in 0..n {
+        td.update(f64::INFINITY);
+    }
+    assert!(td.is_empty());
+
+    let mut td = TDigestMut::new(10);
+    for _ in 0..n {
+        td.update(f64::NEG_INFINITY);
+    }
+    assert!(td.is_empty());
+
+    let mut td = TDigestMut::new(10);
+    for i in 0..n {
+        if i % 2 == 0 {
+            td.update(f64::INFINITY);
+        } else {
+            td.update(f64::NEG_INFINITY);
+        }
+    }
+    assert!(td.is_empty());
+}
diff --git a/tests/test_data/tdigest_ref_k100_n10000_double.sk b/tests/test_data/tdigest_ref_k100_n10000_double.sk
new file mode 100644
index 0000000..f6f4510
Binary files /dev/null and b/tests/test_data/tdigest_ref_k100_n10000_double.sk differ
diff --git a/tests/test_data/tdigest_ref_k100_n10000_float.sk b/tests/test_data/tdigest_ref_k100_n10000_float.sk
new file mode 100644
index 0000000..16d7981
Binary files /dev/null and b/tests/test_data/tdigest_ref_k100_n10000_float.sk differ
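As a closing illustration of the accuracy profile the module docs describe (better near the extremes than at the median), a small empirical sweep (not part of the patch; output values will vary with the input):

```rust
use datasketches::tdigest::TDigestMut;

fn main() {
    let n = 100_000;
    let mut td = TDigestMut::default();
    for i in 0..n {
        td.update(i as f64);
    }
    let td = td.freeze();

    // Rank error should be smallest near q = 0.01 / 0.99 and largest near q = 0.5.
    for q in [0.01, 0.25, 0.5, 0.75, 0.99] {
        let est = td.rank(q * n as f64).unwrap();
        println!("q = {q}: estimated rank = {est:.5}, error = {:.5}", (est - q).abs());
    }
}
```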