From 9e9f268af6873a3c08a2766f30051a33a84595f6 Mon Sep 17 00:00:00 2001 From: "J.C. Jones" Date: Wed, 4 Feb 2026 12:49:27 -0700 Subject: [PATCH] DAP-16 Leader Job State Updates Leader jobs should be transitioned from Active to Finished by the Driver, not by the Writer. Fixes #4320 --- .../src/aggregator/aggregation_job_driver.rs | 13 +- .../aggregation_job_driver/tests.rs | 379 ++++++++++++++++++ 2 files changed, 391 insertions(+), 1 deletion(-) diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index cc279251b..531c11650 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -1373,9 +1373,20 @@ where }), task_aggregation_counters.clone(), ); + + // Determine the next Leader job state based on whether all reports are terminal. + let all_terminal = report_aggregations_to_write + .iter() + .all(|ra| ra.is_terminal()); + let new_state = if all_terminal { + AggregationJobState::Finished + } else { + AggregationJobState::Active + }; + let new_step = aggregation_job.step().increment(); aggregation_job_writer.put( - aggregation_job.with_step(new_step), + aggregation_job.with_step(new_step).with_state(new_state), report_aggregations_to_write, )?; let aggregation_job_writer = Arc::new(aggregation_job_writer); diff --git a/aggregator/src/aggregator/aggregation_job_driver/tests.rs b/aggregator/src/aggregator/aggregation_job_driver/tests.rs index 86aba5f67..75c4095e1 100644 --- a/aggregator/src/aggregator/aggregation_job_driver/tests.rs +++ b/aggregator/src/aggregator/aggregation_job_driver/tests.rs @@ -6596,3 +6596,382 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { )]), ); } + +#[tokio::test] +async fn leader_job_state_transitions_to_active_on_continue() { + install_test_trace_subscriber(); + initialize_rustls(); + let mut server = mockito::Server::new_async().await; + let clock = MockClock::default(); + let 
ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); + let leader_task = task.leader_view().unwrap(); + let time = clock.now().to_time(task.time_precision()); + let batch_identifier = TimeInterval::to_batch_identifier(&(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(7); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &13, + ); + + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let (task, aggregation_param, report) = + (leader_task.clone(), aggregation_param, report.clone()); + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::minimal(Time::from_time_precision_units(0)).unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation( + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), + ) + .await + .unwrap(); + + 
tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new(
                    *task.id(),
                    batch_identifier,
                    aggregation_param,
                    0,
                    Interval::minimal(time).unwrap(),
                    BatchAggregationState::Aggregating {
                        aggregate_share: None,
                        report_count: 0,
                        checksum: ReportIdChecksum::default(),
                        aggregation_jobs_created: 1,
                        aggregation_jobs_terminated: 0,
                    },
                ))
                .await
                .unwrap();

                Ok(tx
                    .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1)
                    .await
                    .unwrap()
                    .remove(0))
            })
        })
        .await
        .unwrap();

    let leader_request = AggregationJobInitializeReq::new(
        aggregation_param.get_encoded().unwrap(),
        PartialBatchSelector::new_time_interval(),
        Vec::from([PrepareInit::new(
            ReportShare::new(
                report.metadata().clone(),
                report.public_share().get_encoded().unwrap(),
                report.helper_encrypted_input_share().clone(),
            ),
            transcript.leader_prepare_transitions[0]
                .message()
                .unwrap()
                .clone(),
        )]),
    );
    let helper_response = AggregationJobResp {
        prepare_resps: Vec::from([PrepareResp::new(
            *report.metadata().id(),
            PrepareStepResult::Continue {
                message: transcript.helper_prepare_transitions[0]
                    .message()
                    .unwrap()
                    .clone(),
            },
        )]),
    };

    server
        .mock(
            "PUT",
            task.aggregation_job_uri(&aggregation_job_id, None)
                .unwrap()
                .path(),
        )
        .match_header(
            CONTENT_TYPE.as_str(),
            AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
        )
        .match_body(leader_request.get_encoded().unwrap())
        .with_status(201)
        .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE)
        .with_body(helper_response.get_encoded().unwrap())
        .create_async()
        .await;

    let aggregation_job_driver = AggregationJobDriver::new(
        reqwest::Client::builder().build().unwrap(),
        LimitedRetryer::new(0),
        &noop_meter(),
        BATCH_AGGREGATION_SHARD_COUNT,
        TASK_AGGREGATION_COUNTER_SHARD_COUNT,
        HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL,
        DEFAULT_ASYNC_POLL_INTERVAL,
    );
    aggregation_job_driver
.step_aggregation_job(
            ds.clone(),
            Arc::new(
                HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL)
                    .await
                    .unwrap(),
            ),
            Arc::new(lease),
        )
        .await
        .unwrap();

    let got_aggregation_job = ds
        .run_unnamed_tx(|tx| {
            let task = leader_task.clone();
            Box::pin(async move {
                tx.get_aggregation_job::<0, TimeInterval, dummy::Vdaf>(
                    task.id(),
                    &aggregation_job_id,
                )
                .await
            })
        })
        .await
        .unwrap()
        .unwrap();

    assert_eq!(got_aggregation_job.state(), &AggregationJobState::Active);
    assert_eq!(got_aggregation_job.step(), AggregationJobStep::from(1));
}

#[tokio::test]
async fn leader_job_state_transitions_to_finished_on_all_failures() {
    install_test_trace_subscriber();
    initialize_rustls();
    let mut server = mockito::Server::new_async().await;
    let clock = MockClock::default();
    let ephemeral_datastore = ephemeral_datastore().await;
    let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
    ds.put_hpke_key().await.unwrap();
    let vdaf = Arc::new(Prio3::new_count(2).unwrap());

    let task = TaskBuilder::new(
        BatchMode::TimeInterval,
        AggregationMode::Synchronous,
        VdafInstance::Prio3Count,
    )
    .with_helper_aggregator_endpoint(server.url().parse().unwrap())
    .build();
    let leader_task = task.leader_view().unwrap();
    let time = clock.now().to_time(task.time_precision());
    let batch_identifier = TimeInterval::to_batch_identifier(&(), &time).unwrap();
    let verify_key: VerifyKey<VERIFY_KEY_LENGTH_PRIO3> = task.vdaf_verify_key().unwrap();

    let transcript = run_vdaf(
        vdaf.as_ref(),
        task.id(),
        verify_key.as_bytes(),
        &(),
        &random(),
        &false,
    );

    let helper_hpke_keypair = HpkeKeypair::test();
    let invalid_report = LeaderStoredReport::generate(
        *task.id(),
        ReportMetadata::new(
            random(),
            time,
            Vec::from([
                Extension::new(ExtensionType::Reserved, Vec::new()),
                Extension::new(ExtensionType::Reserved, Vec::new()),
            ]),
        ),
        helper_hpke_keypair.config(),
        Vec::new(),
&transcript, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let (task, report) = (leader_task.clone(), invalid_report.clone()); + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::minimal(Time::from_time_precision_units(0)).unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation( + &report.as_leader_init_report_aggregation(aggregation_job_id, 0), + ) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *task.id(), + batch_identifier, + (), + 0, + Interval::minimal(time).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + Ok(tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0)) + }) + }) + .await + .unwrap(); + + let leader_request = AggregationJobInitializeReq::new( + ().get_encoded().unwrap(), + PartialBatchSelector::new_time_interval(), + Vec::from([PrepareInit::new( + ReportShare::new( + invalid_report.metadata().clone(), + invalid_report.public_share().get_encoded().unwrap(), + invalid_report.helper_encrypted_input_share().clone(), + ), + transcript.leader_prepare_transitions[0] + .message() + .unwrap() + .clone(), + )]), + ); + let helper_response = AggregationJobResp { + prepare_resps: Vec::from([PrepareResp::new( + *invalid_report.metadata().id(), + 
PrepareStepResult::Reject(ReportError::InvalidMessage),
        )]),
    };

    server
        .mock(
            "PUT",
            task.aggregation_job_uri(&aggregation_job_id, None)
                .unwrap()
                .path(),
        )
        .match_header(
            CONTENT_TYPE.as_str(),
            AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
        )
        .match_body(leader_request.get_encoded().unwrap())
        .with_status(201)
        .with_header(CONTENT_TYPE.as_str(), AggregationJobResp::MEDIA_TYPE)
        .with_body(helper_response.get_encoded().unwrap())
        .create_async()
        .await;

    let aggregation_job_driver = AggregationJobDriver::new(
        reqwest::Client::builder().build().unwrap(),
        LimitedRetryer::new(0),
        &noop_meter(),
        BATCH_AGGREGATION_SHARD_COUNT,
        TASK_AGGREGATION_COUNTER_SHARD_COUNT,
        HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL,
        DEFAULT_ASYNC_POLL_INTERVAL,
    );
    aggregation_job_driver
        .step_aggregation_job(
            ds.clone(),
            Arc::new(
                HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL)
                    .await
                    .unwrap(),
            ),
            Arc::new(lease),
        )
        .await
        .unwrap();

    let got_aggregation_job = ds
        .run_unnamed_tx(|tx| {
            let task = leader_task.clone();
            Box::pin(async move {
                tx.get_aggregation_job::<VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count>(
                    task.id(),
                    &aggregation_job_id,
                )
                .await
            })
        })
        .await
        .unwrap()
        .unwrap();

    assert_eq!(got_aggregation_job.state(), &AggregationJobState::Finished);
    assert_eq!(got_aggregation_job.step(), AggregationJobStep::from(1));
}