diff --git a/client/src/lib.rs b/client/src/lib.rs index 9f5aaa96b..f48392006 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -63,13 +63,17 @@ use janus_core::{ }; use janus_messages::{ HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId, - ReportMetadata, Role, TaskId, Time, UploadRequest, taskprov::TimePrecision, + ReportMetadata, ReportUploadStatus, Role, TaskId, Time, UploadRequest, UploadResponse, + taskprov::TimePrecision, }; #[cfg(feature = "ohttp")] use ohttp::{ClientRequest, KeyConfig}; #[cfg(feature = "ohttp")] use ohttp_keys::OhttpKeys; -use prio::{codec::Encode, vdaf}; +use prio::{ + codec::{Encode, ParameterizedDecode}, + vdaf, +}; use rand::random; use tokio::sync::Mutex; use url::Url; @@ -77,7 +81,6 @@ use url::Url; #[cfg(test)] mod tests; -// TODO(Issue #4146): need way to convey per-report errors #[derive(Debug, thiserror::Error)] pub enum Error { #[error("invalid parameter {0}")] @@ -88,6 +91,8 @@ pub enum Error { Codec(#[from] prio::codec::CodecError), #[error("HTTP response status {0}")] Http(Box), + #[error("upload failed for {} report(s)", .0.len())] + Upload(Vec), #[error("URL parse: {0}")] Url(#[from] url::ParseError), #[error("VDAF error: {0}")] @@ -627,11 +632,12 @@ impl> Client { .reports_resource_uri(&self.parameters.task_id)?; #[cfg(feature = "ohttp")] - let upload_status = self + let (upload_status, upload_response) = self .upload_with_ohttp(&upload_endpoint, &upload_request) .await?; #[cfg(not(feature = "ohttp"))] - let upload_status = self.put_report(&upload_endpoint, &upload_request).await?; + let (upload_status, upload_response) = + self.put_report(&upload_endpoint, &upload_request).await?; if !upload_status.is_success() { return Err(Error::Http(Box::new(HttpErrorResponse::from( @@ -639,6 +645,13 @@ impl> Client { )))); } + if let Some(upload_response) = upload_response { + let failed_reports = upload_response.status(); + if !failed_reports.is_empty() { + return Err(Error::Upload(failed_reports.to_vec())); + } + } + Ok(()) } @@ -646,8 +659,8 @@ impl> Client { &self, upload_endpoint: &Url, request_body: &[u8], - ) -> Result { - Ok(retry_http_request( + ) -> Result<(StatusCode, Option), Error> { + let response = retry_http_request( self.parameters.http_request_retry_parameters.build(), || async { self.http_client @@ -658,19 +671,36 @@ impl> Client { .await }, ) - .await? - .status()) + .await?; + + let status = response.status(); + let upload_response = if response + .headers() + .get(CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + == Some(UploadResponse::MEDIA_TYPE) + { + let body = response.body(); + Some(UploadResponse::get_decoded_with_param( + &body.len(), + body.as_ref(), + )?) + } else { + None + }; + + Ok((status, upload_response)) } /// Send a DAP upload request via OHTTP, if the client is configured to use it, or directly if - /// not. This is only intended for DAP uploads and so does not handle response bodies. + /// not. #[cfg(feature = "ohttp")] #[tracing::instrument(skip(self, request_body), err)] async fn upload_with_ohttp( &self, upload_endpoint: &Url, request_body: &[u8], - ) -> Result { + ) -> Result<(StatusCode, Option), Error> { let ohttp_config = if let Some(ohttp_config) = &self.ohttp_config { ohttp_config } else { @@ -750,7 +780,20 @@ impl> Client { )); }; - Ok(status) + let upload_response = message + .header() + .iter() + .find(|field| field.name() == CONTENT_TYPE.as_str().as_bytes()) + .and_then(|field| { + if field.value() == UploadResponse::MEDIA_TYPE.as_bytes() { + let content = message.content(); + UploadResponse::get_decoded_with_param(&content.len(), content).ok() + } else { + None + } + }); + + Ok((status, upload_response)) } } diff --git a/client/src/tests/mod.rs b/client/src/tests/mod.rs index aa28bd110..4e68590e0 100644 --- a/client/src/tests/mod.rs +++ b/client/src/tests/mod.rs @@ -7,7 +7,8 @@ use janus_core::{ test_util::install_test_trace_subscriber, }; use janus_messages::{ - HpkeConfigList, MediaType, Role, Time, UploadRequest, taskprov::TimePrecision, + HpkeConfigList, MediaType, ReportError, ReportUploadStatus, Role, Time, UploadRequest, + UploadResponse, taskprov::TimePrecision, }; use prio::{ codec::Encode, @@ -270,3 +271,109 @@ async fn unsupported_hpke_algorithms() { mock.assert_async().await; } + +#[tokio::test] +async fn upload_with_per_report_errors() { + install_test_trace_subscriber(); + initialize_rustls(); + let mut server = mockito::Server::new_async().await; + let client = setup_client(&server, Prio3::new_count(2).unwrap()).await; + + // Create a report to determine its ID so we can create a matching error response + let report = client + .prepare_report( + &true, + &Time::from_seconds_since_epoch(100, &client.parameters.time_precision), + client.leader_hpke_config.lock().await.get().await.unwrap(), + client.helper_hpke_config.lock().await.get().await.unwrap(), + ) + .unwrap(); + let report_id = *report.metadata().id(); + + // Create an UploadResponse with a per-report error + let upload_response = UploadResponse::new(&[ReportUploadStatus::new( + report_id, + ReportError::ReportReplayed, + )]); + let response_body = upload_response.get_encoded().unwrap(); + + let mocked_upload = server + .mock( + "POST", + format!("/tasks/{}/reports", client.parameters.task_id).as_str(), + ) + .match_header(CONTENT_TYPE.as_str(), UploadRequest::MEDIA_TYPE) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), UploadResponse::MEDIA_TYPE) + .with_body(response_body) + .expect(1) + .create_async() + .await; + + // Upload should fail even though HTTP status is 200 + let result = client.upload(true).await; + assert_matches!( + result, + Err(Error::Upload(statuses)) => { + assert_eq!(statuses.len(), 1); + assert_eq!(statuses[0].report_id(), report_id); + assert_eq!(statuses[0].error(), ReportError::ReportReplayed); + } + ); + + mocked_upload.assert_async().await; +} + +#[tokio::test] +async fn upload_success_without_response_body() { + install_test_trace_subscriber(); + initialize_rustls(); + let mut server = mockito::Server::new_async().await; + let client = setup_client(&server, Prio3::new_count(2).unwrap()).await; + + let mocked_upload = server + .mock( + "POST", + format!("/tasks/{}/reports", client.parameters.task_id).as_str(), + ) + .match_header(CONTENT_TYPE.as_str(), UploadRequest::MEDIA_TYPE) + .with_status(200) + .expect(1) + .create_async() + .await; + + // Upload should succeed when HTTP 200 is returned without a response body + client.upload(true).await.unwrap(); + + mocked_upload.assert_async().await; +} + +#[tokio::test] +async fn upload_success_with_empty_response() { + install_test_trace_subscriber(); + initialize_rustls(); + let mut server = mockito::Server::new_async().await; + let client = setup_client(&server, Prio3::new_count(2).unwrap()).await; + + // Empty UploadResponse (no errors) + let upload_response = UploadResponse::new(&[]); + let response_body = upload_response.get_encoded().unwrap(); + + let mocked_upload = server + .mock( + "POST", + format!("/tasks/{}/reports", client.parameters.task_id).as_str(), + ) + .match_header(CONTENT_TYPE.as_str(), UploadRequest::MEDIA_TYPE) + .with_status(200) + .with_header(CONTENT_TYPE.as_str(), UploadResponse::MEDIA_TYPE) + .with_body(response_body) + .expect(1) + .create_async() + .await; + + // Upload should succeed when HTTP 200 is returned with an empty UploadResponse + client.upload(true).await.unwrap(); + + mocked_upload.assert_async().await; +}