Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,24 @@ use janus_core::{
};
use janus_messages::{
HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId,
ReportMetadata, Role, TaskId, Time, UploadRequest, taskprov::TimePrecision,
ReportMetadata, ReportUploadStatus, Role, TaskId, Time, UploadRequest, UploadResponse,
taskprov::TimePrecision,
};
#[cfg(feature = "ohttp")]
use ohttp::{ClientRequest, KeyConfig};
#[cfg(feature = "ohttp")]
use ohttp_keys::OhttpKeys;
use prio::{codec::Encode, vdaf};
use prio::{
codec::{Encode, ParameterizedDecode},
vdaf,
};
use rand::random;
use tokio::sync::Mutex;
use url::Url;

#[cfg(test)]
mod tests;

// TODO(Issue #4146): need way to convey per-report errors
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("invalid parameter {0}")]
Expand All @@ -88,6 +91,8 @@ pub enum Error {
Codec(#[from] prio::codec::CodecError),
#[error("HTTP response status {0}")]
Http(Box<HttpErrorResponse>),
#[error("upload failed for {} report(s)", .0.len())]
Upload(Vec<ReportUploadStatus>),
#[error("URL parse: {0}")]
Url(#[from] url::ParseError),
#[error("VDAF error: {0}")]
Expand Down Expand Up @@ -627,27 +632,35 @@ impl<V: vdaf::Client<16>> Client<V> {
.reports_resource_uri(&self.parameters.task_id)?;

#[cfg(feature = "ohttp")]
let upload_status = self
let (upload_status, upload_response) = self
.upload_with_ohttp(&upload_endpoint, &upload_request)
.await?;
#[cfg(not(feature = "ohttp"))]
let upload_status = self.put_report(&upload_endpoint, &upload_request).await?;
let (upload_status, upload_response) =
self.put_report(&upload_endpoint, &upload_request).await?;

if !upload_status.is_success() {
return Err(Error::Http(Box::new(HttpErrorResponse::from(
upload_status,
))));
}

if let Some(upload_response) = upload_response {
let failed_reports = upload_response.status();
if !failed_reports.is_empty() {
return Err(Error::Upload(failed_reports.to_vec()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of a shame to have to copy the &[ReportUploadStatus] to Vec<ReportUploadStatus> since the enclosing struct gets dropped. We could do UploadResponse::take_status(self) -> Vec<ReportUploadStatus> and explicitly yoink the Vec. But honestly the compiler is probably smart enough to avoid the copy here, and it's not worth investing work until/unless we measure a performance problem here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be short, too, since it's only the failed jobs.

... should. :)

}
}

Ok(())
}

async fn put_report(
&self,
upload_endpoint: &Url,
request_body: &[u8],
) -> Result<StatusCode, Error> {
Ok(retry_http_request(
) -> Result<(StatusCode, Option<UploadResponse>), Error> {
let response = retry_http_request(
self.parameters.http_request_retry_parameters.build(),
|| async {
self.http_client
Expand All @@ -658,19 +671,36 @@ impl<V: vdaf::Client<16>> Client<V> {
.await
},
)
.await?
.status())
.await?;

let status = response.status();
let upload_response = if response
.headers()
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
== Some(UploadResponse::MEDIA_TYPE)
{
let body = response.body();
Some(UploadResponse::get_decoded_with_param(
&body.len(),
body.as_ref(),
)?)
} else {
None
};

Ok((status, upload_response))
}

/// Send a DAP upload request via OHTTP, if the client is configured to use it, or directly if
/// not. This is only intended for DAP uploads and so does not handle response bodies.
/// not.
#[cfg(feature = "ohttp")]
#[tracing::instrument(skip(self, request_body), err)]
async fn upload_with_ohttp(
&self,
upload_endpoint: &Url,
request_body: &[u8],
) -> Result<StatusCode, Error> {
) -> Result<(StatusCode, Option<UploadResponse>), Error> {
let ohttp_config = if let Some(ohttp_config) = &self.ohttp_config {
ohttp_config
} else {
Expand Down Expand Up @@ -750,7 +780,20 @@ impl<V: vdaf::Client<16>> Client<V> {
));
};

Ok(status)
let upload_response = message
.header()
.iter()
.find(|field| field.name() == CONTENT_TYPE.as_str().as_bytes())
.and_then(|field| {
if field.value() == UploadResponse::MEDIA_TYPE.as_bytes() {
let content = message.content();
UploadResponse::get_decoded_with_param(&content.len(), content).ok()
} else {
None
}
});

Ok((status, upload_response))
}
}

Expand Down
109 changes: 108 additions & 1 deletion client/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use janus_core::{
test_util::install_test_trace_subscriber,
};
use janus_messages::{
HpkeConfigList, MediaType, Role, Time, UploadRequest, taskprov::TimePrecision,
HpkeConfigList, MediaType, ReportError, ReportUploadStatus, Role, Time, UploadRequest,
UploadResponse, taskprov::TimePrecision,
};
use prio::{
codec::Encode,
Expand Down Expand Up @@ -270,3 +271,109 @@ async fn unsupported_hpke_algorithms() {

mock.assert_async().await;
}

#[tokio::test]
async fn upload_with_per_report_errors() {
install_test_trace_subscriber();
initialize_rustls();
let mut server = mockito::Server::new_async().await;
let client = setup_client(&server, Prio3::new_count(2).unwrap()).await;

// Create a report to determine its ID so we can create a matching error response
let report = client
.prepare_report(
&true,
&Time::from_seconds_since_epoch(100, &client.parameters.time_precision),
client.leader_hpke_config.lock().await.get().await.unwrap(),
client.helper_hpke_config.lock().await.get().await.unwrap(),
)
.unwrap();
let report_id = *report.metadata().id();

// Create an UploadResponse with a per-report error
let upload_response = UploadResponse::new(&[ReportUploadStatus::new(
report_id,
ReportError::ReportReplayed,
)]);
let response_body = upload_response.get_encoded().unwrap();

let mocked_upload = server
.mock(
"POST",
format!("/tasks/{}/reports", client.parameters.task_id).as_str(),
)
.match_header(CONTENT_TYPE.as_str(), UploadRequest::MEDIA_TYPE)
.with_status(200)
.with_header(CONTENT_TYPE.as_str(), UploadResponse::MEDIA_TYPE)
.with_body(response_body)
.expect(1)
.create_async()
.await;

// Upload should fail even though HTTP status is 200
let result = client.upload(true).await;
assert_matches!(
result,
Err(Error::Upload(statuses)) => {
assert_eq!(statuses.len(), 1);
assert_eq!(statuses[0].report_id(), report_id);
assert_eq!(statuses[0].error(), ReportError::ReportReplayed);
}
);

mocked_upload.assert_async().await;
}

#[tokio::test]
async fn upload_success_without_response_body() {
install_test_trace_subscriber();
initialize_rustls();
let mut server = mockito::Server::new_async().await;
let client = setup_client(&server, Prio3::new_count(2).unwrap()).await;

let mocked_upload = server
.mock(
"POST",
format!("/tasks/{}/reports", client.parameters.task_id).as_str(),
)
.match_header(CONTENT_TYPE.as_str(), UploadRequest::MEDIA_TYPE)
.with_status(200)
.expect(1)
.create_async()
.await;

// Upload should succeed when HTTP 200 is returned without a response body
client.upload(true).await.unwrap();

mocked_upload.assert_async().await;
}

#[tokio::test]
async fn upload_success_with_empty_response() {
install_test_trace_subscriber();
initialize_rustls();
let mut server = mockito::Server::new_async().await;
let client = setup_client(&server, Prio3::new_count(2).unwrap()).await;

// Empty UploadResponse (no errors)
let upload_response = UploadResponse::new(&[]);
let response_body = upload_response.get_encoded().unwrap();

let mocked_upload = server
.mock(
"POST",
format!("/tasks/{}/reports", client.parameters.task_id).as_str(),
)
.match_header(CONTENT_TYPE.as_str(), UploadRequest::MEDIA_TYPE)
.with_status(200)
.with_header(CONTENT_TYPE.as_str(), UploadResponse::MEDIA_TYPE)
.with_body(response_body)
.expect(1)
.create_async()
.await;

// Upload should succeed when HTTP 200 is returned with an empty UploadResponse
client.upload(true).await.unwrap();

mocked_upload.assert_async().await;
}
Loading