From 0b057f694e1f7c973435df03d9ac7ecfe2997e16 Mon Sep 17 00:00:00 2001 From: "J.C. Jones" Date: Tue, 9 Sep 2025 13:33:54 -0700 Subject: [PATCH] Add an InstantClock trait and plumb it into the LIFOQueueHandler and Datastore. Follow-on to PR #4046 Resolves #4044 --- aggregator/src/aggregator/queue.rs | 119 +++++++++++----- aggregator/src/binary_utils.rs | 8 +- aggregator_core/src/datastore.rs | 50 +++++-- aggregator_core/src/datastore/test_util.rs | 14 +- aggregator_core/src/datastore/tests.rs | 1 + core/src/time.rs | 156 ++++++++++++++++++++- 6 files changed, 296 insertions(+), 52 deletions(-) diff --git a/aggregator/src/aggregator/queue.rs b/aggregator/src/aggregator/queue.rs index e6e60af3b..7a140442b 100644 --- a/aggregator/src/aggregator/queue.rs +++ b/aggregator/src/aggregator/queue.rs @@ -1,5 +1,6 @@ use itertools::Itertools; use janus_aggregator_core::TIME_HISTOGRAM_BOUNDARIES; +use janus_core::time::{InstantClock, InstantLike, RealInstantClock}; use opentelemetry::{ KeyValue, metrics::{Histogram, Meter}, @@ -11,7 +12,6 @@ use std::{ Arc, atomic::{AtomicU64, Ordering}, }, - time::Instant, }; use tokio::{ select, @@ -38,7 +38,7 @@ use super::Error; /// because the order that request futures are scheduled and executed in is essentially /// non-deterministic. #[derive(Debug)] -pub struct LIFORequestQueue { +pub struct LIFORequestQueue { /// Sends messages to the dispatcher task. dispatcher_tx: mpsc::UnboundedSender, @@ -48,11 +48,13 @@ pub struct LIFORequestQueue { /// counter. id_counter: AtomicU64, - metrics: Metrics, + metrics: Metrics, + + instant_clock: C, } impl LIFORequestQueue { - /// Creates a new [`Self`]. + /// Creates a new [`Self`] with the real clock. /// /// `concurrency` must be greater than zero. /// @@ -63,6 +65,24 @@ impl LIFORequestQueue { depth: usize, meter: &Meter, meter_prefix: &str, + ) -> Result { + Self::with_instant_clock(concurrency, depth, meter, meter_prefix, RealInstantClock) + } +} + +impl LIFORequestQueue { + /// Creates a new [`Self`] with a custom instant clock. + /// + /// `concurrency` must be greater than zero. + /// + /// `meter_prefix` is a string to disambiguate one queue from another in the metrics, while + /// using the same meter. All metric names will be prefixed with this string. + pub fn with_instant_clock( + concurrency: u32, + depth: usize, + meter: &Meter, + meter_prefix: &str, + instant_clock: C, ) -> Result { if concurrency < 1 { return Err(Error::InvalidConfiguration( @@ -82,6 +102,7 @@ impl LIFORequestQueue { id_counter, dispatcher_tx: message_tx, metrics, + instant_clock, }) } @@ -98,7 +119,7 @@ impl LIFORequestQueue { mut dispatcher_rx: mpsc::UnboundedReceiver, concurrency: u32, depth: usize, - metrics: Metrics, + metrics: Metrics, ) -> JoinHandle<()> { tokio::spawn(async move { // Use a BTreeMap to allow for cancellation (i.e. removal) of waiting requests in @@ -184,7 +205,7 @@ impl LIFORequestQueue { let id = self.id_counter.fetch_add(1, Ordering::Relaxed); let (permit_tx, permit_rx) = oneshot::channel(); - let enqueue_time = Instant::now(); + let enqueue_time = self.instant_clock.now(); self.dispatcher_tx .send(DispatcherMessage::Enqueue(id, PermitTx(permit_tx))) // We don't necessarily panic because the dispatcher task could be shutdown as part of @@ -193,20 +214,20 @@ impl LIFORequestQueue { /// Sends a cancellation message over the given channel when the guard is dropped, unless /// [`Self::disarm`] is called. - struct CancelDropGuard { + struct CancelDropGuard { id: u64, sender: mpsc::UnboundedSender, armed: bool, - metrics: Metrics, - enqueue_time: Instant, + metrics: Metrics, + enqueue_time: I, } - impl CancelDropGuard { + impl CancelDropGuard { fn new( id: u64, sender: mpsc::UnboundedSender, - metrics: Metrics, - enqueue_time: Instant, + metrics: Metrics, + enqueue_time: I, ) -> Self { Self { id, @@ -222,7 +243,7 @@ impl LIFORequestQueue { } } - impl Drop for CancelDropGuard { + impl Drop for CancelDropGuard { fn drop(&mut self) { if self.armed { self.metrics.wait_time_histogram.record( @@ -237,11 +258,11 @@ impl LIFORequestQueue { } } - let mut drop_guard = CancelDropGuard::new( + let mut drop_guard = CancelDropGuard::::new( id, self.dispatcher_tx.clone(), self.metrics.clone(), - enqueue_time, + enqueue_time.clone(), ); let permit = permit_rx.await; drop_guard.disarm(); @@ -288,14 +309,14 @@ impl PermitTx { /// /// Multiple request handlers can share a queue, by cloning the [`Arc`] that wraps the queue. #[derive(Handler)] -pub struct LIFOQueueHandler { +pub struct LIFOQueueHandler { #[handler(except = [run])] handler: H, - queue: Arc, + queue: Arc>, } -impl LIFOQueueHandler { - pub fn new(queue: Arc, handler: H) -> Self { +impl LIFOQueueHandler { + pub fn new(queue: Arc>, handler: H) -> Self { Self { handler, queue } } @@ -311,12 +332,15 @@ impl LIFOQueueHandler { } /// Convenience function for wrapping a handler with a [`LIFOQueueHandler`]. -pub fn queued_lifo(queue: Arc, handler: H) -> impl Handler { +pub fn queued_lifo( + queue: Arc>, + handler: H, +) -> LIFOQueueHandler { LIFOQueueHandler::new(queue, handler) } #[derive(Clone, Debug)] -struct Metrics { +struct Metrics { /// The approximate number of requests currently being serviced by the queue. It's approximate /// since the queue length may have changed before the measurement is taken. In practice, the /// error should only be +/- 1. It is also more or less suitable for synchronization during @@ -325,9 +349,11 @@ struct Metrics { /// Histogram measuring how long a queue item waited before being dequeued. wait_time_histogram: Histogram, + + _phantom: std::marker::PhantomData, } -impl Metrics { +impl Metrics { const OUTSTANDING_REQUESTS_METRIC_NAME: &'static str = "outstanding_requests"; const MAX_OUTSTANDING_REQUESTS_METRIC_NAME: &'static str = "max_outstanding_requests"; const WAIT_TIME_METRIC_NAME: &'static str = "lifo_queue_wait_time"; @@ -378,6 +404,7 @@ impl Metrics { Ok(Self { outstanding_requests, wait_time_histogram, + _phantom: std::marker::PhantomData, }) } } @@ -396,14 +423,17 @@ mod tests { use backon::{BackoffBuilder, ExponentialBuilder, Retryable}; use futures::{Future, future::join_all}; use janus_aggregator_core::test_util::noop_meter; - use janus_core::test_util::install_test_trace_subscriber; + use janus_core::{ + test_util::install_test_trace_subscriber, + time::{MockInstantClock, RealInstantClock}, + }; use opentelemetry_sdk::metrics::data::Gauge; use quickcheck::{Arbitrary, TestResult, quickcheck}; use tokio::{ runtime::Builder as RuntimeBuilder, sync::Notify, task::{JoinHandle, yield_now}, - time::{sleep, timeout}, + time::timeout, }; use tracing::debug; use trillium::{Conn, Handler, Status}; @@ -430,9 +460,9 @@ mod tests { .await // The metric may not be immediately available when we need it, so return an Option // instead of unwrapping. - .get(&Metrics::metric_name( + .get(&Metrics::::metric_name( meter_prefix, - Metrics::OUTSTANDING_REQUESTS_METRIC_NAME, + Metrics::::OUTSTANDING_REQUESTS_METRIC_NAME, ))? .data .as_any() @@ -455,8 +485,8 @@ mod tests { .map(|metric| !condition(metric)) .unwrap_or(true) { - // Nominal sleep to prevent this loop from being too tight. - sleep(Duration::from_millis(3)).await; + // Yield to allow other tasks to run + yield_now().await; } } @@ -635,9 +665,16 @@ mod tests { let meter_prefix = "test"; let metrics = InMemoryMetricInfrastructure::new(); let unhang = Arc::new(Notify::new()); + let instant_clock = MockInstantClock::new(); let queue = Arc::new( - LIFORequestQueue::new(concurrency, depth, &metrics.meter, meter_prefix) - .unwrap(), + LIFORequestQueue::with_instant_clock( + concurrency, + depth, + &metrics.meter, + meter_prefix, + instant_clock, + ) + .unwrap(), ); let handler = Arc::new(queued_lifo( Arc::clone(&queue), @@ -716,9 +753,16 @@ mod tests { let unhang = Arc::new(Notify::new()); let meter_prefix = "test"; let metrics = InMemoryMetricInfrastructure::new(); + let instant_clock = MockInstantClock::new(); let queue = Arc::new( - LIFORequestQueue::new(concurrency, depth, &metrics.meter, meter_prefix) - .unwrap(), + LIFORequestQueue::with_instant_clock( + concurrency, + depth, + &metrics.meter, + meter_prefix, + instant_clock, + ) + .unwrap(), ); let handler = Arc::new(queued_lifo( Arc::clone(&queue), @@ -781,9 +825,16 @@ mod tests { let unhang = Arc::new(Notify::new()); let meter_prefix = "test"; let metrics = InMemoryMetricInfrastructure::new(); + let instant_clock = MockInstantClock::new(); let queue = Arc::new( - LIFORequestQueue::new(concurrency, depth, &metrics.meter, meter_prefix) - .unwrap(), + LIFORequestQueue::with_instant_clock( + concurrency, + depth, + &metrics.meter, + meter_prefix, + instant_clock, + ) + .unwrap(), ); let handler = Arc::new(queued_lifo( Arc::clone(&queue), diff --git a/aggregator/src/binary_utils.rs b/aggregator/src/binary_utils.rs index 8c8bf3bff..1ba435be6 100644 --- a/aggregator/src/binary_utils.rs +++ b/aggregator/src/binary_utils.rs @@ -139,7 +139,7 @@ pub async fn datastore( datastore_keys: &[String], check_schema_version: bool, max_transaction_retries: u64, -) -> Result> { +) -> Result> { let datastore_keys = datastore_keys .iter() .filter(|k| !k.is_empty()) @@ -165,21 +165,25 @@ pub async fn datastore( } let datastore = if check_schema_version { - Datastore::new( + use janus_core::time::RealInstantClock; + Datastore::with_instant_clock( pool, Crypter::new(datastore_keys), clock, meter, max_transaction_retries, + RealInstantClock, ) .await? } else { + use janus_core::time::RealInstantClock; Datastore::new_without_supported_versions( pool, Crypter::new(datastore_keys), clock, meter, max_transaction_retries, + RealInstantClock, ) .await }; diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 9c3b4ed2c..c7e307468 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -22,7 +22,10 @@ use futures::future::try_join_all; use janus_core::{ auth_tokens::AuthenticationToken, hpke::{self, HpkePrivateKey}, - time::{Clock, DurationExt, IntervalExt, TimeExt}, + time::{ + Clock, DurationExt, InstantClock, InstantLike, IntervalExt, RealInstantClock, + TimeExt, + }, vdaf::VdafInstance, }; use janus_messages::{ @@ -55,7 +58,7 @@ use std::{ Arc, Mutex, atomic::{AtomicBool, Ordering}, }, - time::{Duration as StdDuration, Instant}, + time::Duration as StdDuration, }; use tokio::{sync::Barrier, try_join}; use tokio_postgres::{IsolationLevel, Row, Statement, ToStatement, error::SqlState, row::RowIndex}; @@ -110,7 +113,7 @@ supported_schema_versions!(1); /// Datastore represents a datastore for Janus, with support for transactional reads and writes. /// In practice, Datastore instances are currently backed by a PostgreSQL database. -pub struct Datastore { +pub struct Datastore { pool: deadpool_postgres::Pool, crypter: Crypter, clock: C, @@ -122,9 +125,10 @@ pub struct Datastore { transaction_pool_wait_histogram: Histogram, transaction_total_duration_histogram: Histogram, max_transaction_retries: u64, + instant_clock: I, } -impl Debug for Datastore { +impl Debug for Datastore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Datastore") } @@ -139,7 +143,28 @@ impl Datastore { clock: C, meter: &Meter, max_transaction_retries: u64, - ) -> Result, Error> { + ) -> Result { + Self::with_instant_clock( + pool, + crypter, + clock, + meter, + max_transaction_retries, + RealInstantClock, + ) + .await + } +} + +impl Datastore { + pub async fn with_instant_clock( + pool: deadpool_postgres::Pool, + crypter: Crypter, + clock: C, + meter: &Meter, + max_transaction_retries: u64, + instant_clock: I, + ) -> Result, Error> { Self::new_with_supported_versions( pool, crypter, @@ -147,6 +172,7 @@ impl Datastore { meter, SUPPORTED_SCHEMA_VERSIONS, max_transaction_retries, + instant_clock, ) .await } @@ -158,13 +184,15 @@ impl Datastore { meter: &Meter, supported_schema_versions: &[i64], max_transaction_retries: u64, - ) -> Result, Error> { + instant_clock: I, + ) -> Result, Error> { let datastore = Self::new_without_supported_versions( pool, crypter, clock, meter, max_transaction_retries, + instant_clock, ) .await; @@ -190,7 +218,8 @@ impl Datastore { clock: C, meter: &Meter, max_transaction_retries: u64, - ) -> Datastore { + instant_clock: I, + ) -> Datastore { let transaction_status_counter = meter .u64_counter(TRANSACTION_METER_NAME) .with_description("Count of database transactions run, with their status.") @@ -250,6 +279,7 @@ impl Datastore { transaction_pool_wait_histogram, transaction_total_duration_histogram, max_transaction_retries, + instant_clock, } } @@ -269,7 +299,7 @@ impl Datastore { for<'a> F: Fn(&'a Transaction) -> Pin> + Send + 'a>>, { - let before = Instant::now(); + let before = self.instant_clock.now(); let mut retry_count = 0; loop { let (mut rslt, retry) = self.run_tx_once(name, &f).await; @@ -323,7 +353,7 @@ impl Datastore { Fn(&'a Transaction) -> Pin> + Send + 'a>>, { // Acquire connection from the connection pooler. - let before = Instant::now(); + let before = self.instant_clock.now(); let result = self.pool.get().await; let elapsed = before.elapsed(); // We don't record the transaction name for this metric, since it's not particularly @@ -341,7 +371,7 @@ impl Datastore { }; // Open transaction. - let before = Instant::now(); + let before = self.instant_clock.now(); let raw_tx = match client .build_transaction() .isolation_level(IsolationLevel::RepeatableRead) diff --git a/aggregator_core/src/datastore/test_util.rs b/aggregator_core/src/datastore/test_util.rs index bbea54e12..d52fd0eec 100644 --- a/aggregator_core/src/datastore/test_util.rs +++ b/aggregator_core/src/datastore/test_util.rs @@ -8,7 +8,7 @@ use chrono::NaiveDateTime; use deadpool_postgres::{Manager, Pool, Timeouts}; use janus_core::{ test_util::testcontainers::Postgres, - time::{Clock, MockClock, TimeExt}, + time::{Clock, MockClock, RealInstantClock, TimeExt}, }; use janus_messages::Time; use rand::{Rng, distr::StandardUniform, random, rng}; @@ -217,13 +217,15 @@ pub const TEST_DATASTORE_MAX_TRANSACTION_RETRIES: u64 = 1000; impl EphemeralDatastore { /// Creates a Datastore instance based on this EphemeralDatastore. All returned Datastore /// instances will refer to the same underlying durable state. - pub async fn datastore(&self, clock: C) -> Datastore { - Datastore::new( + pub async fn datastore(&self, clock: C) -> Datastore { + use RealInstantClock; + Datastore::with_instant_clock( self.pool(), self.crypter(), clock, &noop_meter(), TEST_DATASTORE_MAX_TRANSACTION_RETRIES, + RealInstantClock, ) .await .unwrap() @@ -233,13 +235,15 @@ impl EphemeralDatastore { &self, clock: C, max_transaction_retries: u64, - ) -> Datastore { - Datastore::new( + ) -> Datastore { + use RealInstantClock; + Datastore::with_instant_clock( self.pool(), self.crypter(), clock, &noop_meter(), max_transaction_retries, + RealInstantClock, ) .await .unwrap() diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 1364833f8..6d6b52a2f 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -88,6 +88,7 @@ async fn reject_unsupported_schema_version(ephemeral_datastore: EphemeralDatasto &noop_meter(), &[0], TEST_DATASTORE_MAX_TRANSACTION_RETRIES, + janus_core::time::RealInstantClock, ) .await .unwrap_err(); diff --git a/core/src/time.rs b/core/src/time.rs index 1aadec0b3..ae89ec7b3 100644 --- a/core/src/time.rs +++ b/core/src/time.rs @@ -5,6 +5,7 @@ use janus_messages::{Duration, Error, Interval, Time}; use std::{ fmt::{Debug, Formatter}, sync::{Arc, Mutex}, + time::{Duration as StdDuration, Instant as StdInstant}, }; /// A clock knows what time it currently is. @@ -409,10 +410,136 @@ impl IntervalExt for Interval { } } +/// A trait for types that behave like `std::time::Instant` for duration measurements. +pub trait InstantLike: Clone + Debug + Send + Sync { + /// Get the elapsed time since this instant. + fn elapsed(&self) -> StdDuration; +} + +/// Standard library Instant wrapper that implements InstantLike. +#[derive(Clone, Debug)] +pub struct RealInstant(StdInstant); + +impl RealInstant { + pub fn new(instant: StdInstant) -> Self { + Self(instant) + } +} + +impl InstantLike for RealInstant { + fn elapsed(&self) -> StdDuration { + self.0.elapsed() + } +} + +/// A clock that provides instants for duration measurements, separate from DAP time tracking. +pub trait InstantClock: Clone + Debug + Send + Sync { + type Instant: InstantLike; + + /// Get the current instant for timing measurements. + fn now(&self) -> Self::Instant; +} + +/// Real instant clock that uses the system clock. +#[derive(Clone, Copy, Default, Debug)] +pub struct RealInstantClock; + +impl InstantClock for RealInstantClock { + type Instant = RealInstant; + + fn now(&self) -> Self::Instant { + RealInstant::new(StdInstant::now()) + } +} + +#[cfg(feature = "test-util")] +/// Mock instant for testing duration measurements. +#[derive(Clone, Debug)] +pub struct MockInstant { + elapsed: Arc>, +} + +#[cfg(feature = "test-util")] +impl Default for MockInstant { + fn default() -> Self { + Self { + elapsed: Arc::new(Mutex::new(StdDuration::ZERO)), + } + } +} + +#[cfg(feature = "test-util")] +impl MockInstant { + /// Advance the mock elapsed time by the specified duration. + pub fn advance(&self, duration: StdDuration) { + let mut elapsed = self.elapsed.lock().unwrap(); + *elapsed += duration; + } + + pub fn new() -> Self { + Self::default() + } +} + +#[cfg(feature = "test-util")] +impl InstantLike for MockInstant { + fn elapsed(&self) -> StdDuration { + *self.elapsed.lock().unwrap() + } +} + +#[cfg(feature = "test-util")] +/// Mock instant clock for testing. +#[derive(Clone, Debug)] +pub struct MockInstantClock { + current_instant: MockInstant, +} + +#[cfg(feature = "test-util")] +impl Default for MockInstantClock { + fn default() -> Self { + Self { + current_instant: MockInstant::new(), + } + } +} + +#[cfg(feature = "test-util")] +impl MockInstantClock { + pub fn new() -> Self { + Self::default() + } + + /// Advance time by the specified duration. + pub fn advance(&self, duration: StdDuration) { + self.current_instant.advance(duration); + } + + /// Get a reference to the current instant for advancing time in tests. + pub fn current_instant(&self) -> &MockInstant { + &self.current_instant + } +} + +#[cfg(feature = "test-util")] +impl InstantClock for MockInstantClock { + type Instant = MockInstant; + + fn now(&self) -> Self::Instant { + self.current_instant.clone() + } +} + #[cfg(test)] mod tests { - use crate::time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}; + #[cfg(feature = "test-util")] + use crate::time::MockInstantClock; + use crate::time::{ + Clock, DurationExt, InstantClock, InstantLike, IntervalExt, MockClock, RealInstantClock, + TimeExt, + }; use janus_messages::{Duration, Interval, Time}; + use std::time::Duration as StdDuration; #[test] fn round_up_duration() { @@ -637,6 +764,33 @@ mod tests { } } + #[test] + fn instant_clock_basic() { + let instant = RealInstantClock.now(); + std::thread::sleep(std::time::Duration::from_millis(1)); + let elapsed = instant.elapsed(); + assert!(elapsed.as_millis() >= 1); + } + + #[cfg(feature = "test-util")] + #[test] + fn mock_instant_clock() { + let mock_clock = MockInstantClock::new(); + let instant = mock_clock.now(); + + // Initially, elapsed time should be zero + assert_eq!(instant.elapsed(), StdDuration::ZERO); + + // Advance time and check elapsed + let advance_duration = StdDuration::from_millis(100); + mock_clock.advance(advance_duration); + assert_eq!(instant.elapsed(), advance_duration); + + // Create a new instant and verify it has the advanced time + let new_instant = mock_clock.now(); + assert_eq!(new_instant.elapsed(), advance_duration); + } + #[test] fn now_aligned_to_precision() { for (label, timestamp, timestamp_precision, expected) in [