diff --git a/deps/spdm-rs b/deps/spdm-rs index 99020d72..fd5f0355 160000 --- a/deps/spdm-rs +++ b/deps/spdm-rs @@ -1 +1 @@ -Subproject commit 99020d72ac0910f4a7f89f8f305c529b603897e5 +Subproject commit fd5f0355cc7bcc6a0dddbb2fe58d80229da4ae5c diff --git a/doc/memory_usage_test.md b/doc/memory_usage_test.md index f03fc863..398dc2f2 100644 --- a/doc/memory_usage_test.md +++ b/doc/memory_usage_test.md @@ -255,12 +255,11 @@ echo "qom-set /objects/tdx0/ vsockport 1237" | nc -U /tmp/qmp-sock-dst-2 Wait all sessions complete pre-migration, and check the data logged in terminal for memory using status: -(example result) +(example result, migtd-dst) ```bash -INFO - MSK exchange completed -max stack usage: 118128 -max heap usage: 190585 +max stack usage: b3f38 +max heap usage: 140c07 ``` ### Current SPDM attestation memory data @@ -268,7 +267,7 @@ max heap usage: 190585 Current test result for spdm attestation are determined by destination migtd with policy v2 configuration. ```bash -Stack Size = 0x16_0000 +Stack Size = 0x10_0000 Heap Size = 0x12_0000 + 0x5_0000 * session_num ``` diff --git a/src/migtd/src/spdm/spdm_req.rs b/src/migtd/src/spdm/spdm_req.rs index d21f4362..fafcc06f 100644 --- a/src/migtd/src/spdm/spdm_req.rs +++ b/src/migtd/src/spdm/spdm_req.rs @@ -137,7 +137,7 @@ async fn send_and_receive_pub_key(spdm_requester: &mut RequesterContext) -> Spdm vendor_id[..VDM_MESSAGE_VENDOR_ID_LEN].copy_from_slice(&VDM_MESSAGE_VENDOR_ID); let vendor_id = VendorIDStruct { len: 4, vendor_id }; - let mut payload = [0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE]; + let mut payload = vec![0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE]; let mut writer = Writer::init(&mut payload); let mut cnt = 0; @@ -162,46 +162,39 @@ async fn send_and_receive_pub_key(spdm_requester: &mut RequesterContext) -> Spdm .extend_from_slice(my_pub_key.as_slice()) .ok_or(SPDM_STATUS_BUFFER_FULL)?; - let vdm_payload = VendorDefinedReqPayloadStruct { - req_length: cnt as u32, - vendor_defined_req_payload: payload, - }; - - spdm_requester.common.reset_buffer_via_request_code( - SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, - None, - ); - let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let mut writer = Writer::init(&mut send_buffer); - let request = SpdmMessage { - header: SpdmMessageHeader { - version: spdm_requester.common.negotiate_info.spdm_version_sel, - request_response_code: SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, - }, - payload: SpdmMessagePayload::SpdmVendorDefinedRequest(SpdmVendorDefinedRequestPayload { - standard_id: RegistryOrStandardsBodyID::IANA, - vendor_id, - req_payload: vdm_payload, - }), + let request_header = SpdmMessageHeader { + version: spdm_requester.common.negotiate_info.spdm_version_sel, + request_response_code: SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, }; - let used = request.spdm_encode(&mut spdm_requester.common, &mut writer)?; - - spdm_requester - .send_message(None, &send_buffer[..used], false) - .await?; + let request_payload = SpdmVdmRequestPayload { + standard_id: RegistryOrStandardsBodyID::IANA, + vendor_id, + req_length: cnt as u32, + req_payload: payload, + }; + let mut used = 0; + used += request_header + .encode(&mut writer) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + used += request_payload.spdm_encode(&mut spdm_requester.common, &mut writer)?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let receive_used = spdm_requester - .receive_message(None, &mut receive_buffer, false) + let response = spdm_requester + .send_spdm_vendor_defined_request_ex(None, &send_buffer[..used], &mut receive_buffer) .await?; - let vdm_payload = - spdm_requester.handle_spdm_vendor_defined_respond(None, &receive_buffer[..receive_used])?; - // Format checks and save the received public key + let mut reader = Reader::init(response); + let _response_header = + SpdmMessageHeader::read(&mut reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; + let response_payload = + SpdmVdmResponsePayload::spdm_read(&mut spdm_requester.common, &mut reader) + .ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; + let mut reader = - Reader::init(&vdm_payload.vendor_defined_rsp_payload[..vdm_payload.rsp_length as usize]); + Reader::init(&response_payload.rsp_payload[..response_payload.rsp_length as usize]); let vdm_message = VdmMessage::read(&mut reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; if vdm_message.major_version != VDM_MESSAGE_MAJOR_VERSION { error!( @@ -272,8 +265,7 @@ async fn send_and_receive_pub_key(spdm_requester: &mut RequesterContext) -> Spdm let vdm_pub_key_src_hash = digest_sha384(&send_buffer[..used]).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; - let vdm_pub_key_dst_hash = - digest_sha384(&receive_buffer[..receive_used]).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; + let vdm_pub_key_dst_hash = digest_sha384(response).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; let mut transcript_before_key_exchange = ManagedVdmBuffer::default(); transcript_before_key_exchange .append_message(vdm_pub_key_src_hash.as_slice()) @@ -306,7 +298,7 @@ pub async fn send_and_receive_sdm_migration_attest_info( vendor_id[..VDM_MESSAGE_VENDOR_ID_LEN].copy_from_slice(&VDM_MESSAGE_VENDOR_ID); let vendor_id = VendorIDStruct { len: 4, vendor_id }; - let mut payload = [0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE]; + let mut payload = vec![0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE]; let mut writer = Writer::init(&mut payload); let mut cnt = 0; @@ -411,11 +403,6 @@ pub async fn send_and_receive_sdm_migration_attest_info( .extend_from_slice(&mig_policy_src_hash) .ok_or(SPDM_STATUS_BUFFER_FULL)?; - let vdm_payload = VendorDefinedReqPayloadStruct { - req_length: cnt as u32, - vendor_defined_req_payload: payload, - }; - spdm_requester.common.reset_buffer_via_request_code( SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, None, @@ -423,35 +410,37 @@ pub async fn send_and_receive_sdm_migration_attest_info( let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let mut writer = Writer::init(&mut send_buffer); - let request = SpdmMessage { - header: SpdmMessageHeader { - version: spdm_requester.common.negotiate_info.spdm_version_sel, - request_response_code: SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, - }, - payload: SpdmMessagePayload::SpdmVendorDefinedRequest(SpdmVendorDefinedRequestPayload { - standard_id: RegistryOrStandardsBodyID::IANA, - vendor_id, - req_payload: vdm_payload, - }), + let request_header = SpdmMessageHeader { + version: spdm_requester.common.negotiate_info.spdm_version_sel, + request_response_code: SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, }; - let used = request.spdm_encode(&mut spdm_requester.common, &mut writer)?; - - spdm_requester - .send_message(None, &send_buffer[..used], false) - .await?; + let request_payload = SpdmVdmRequestPayload { + standard_id: RegistryOrStandardsBodyID::IANA, + vendor_id, + req_length: cnt as u32, + req_payload: payload, + }; + let mut send_used = 0; + send_used += request_header + .encode(&mut writer) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + send_used += request_payload.spdm_encode(&mut spdm_requester.common, &mut writer)?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let receive_used = spdm_requester - .receive_message(None, &mut receive_buffer, false) + let response = spdm_requester + .send_spdm_vendor_defined_request_ex(None, &send_buffer[..send_used], &mut receive_buffer) .await?; - let vdm_payload = - spdm_requester.handle_spdm_vendor_defined_respond(None, &receive_buffer[..receive_used])?; - //Format checks - let reader = &mut Reader::init( - &vdm_payload.vendor_defined_rsp_payload[..vdm_payload.rsp_length as usize], - ); + let mut reader = Reader::init(response); + let _response_header = + SpdmMessageHeader::read(&mut reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; + let response_payload = + SpdmVdmResponsePayload::spdm_read(&mut spdm_requester.common, &mut reader) + .ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; + + let reader = + &mut Reader::init(&response_payload.rsp_payload[..response_payload.rsp_length as usize]); let vdm_message = VdmMessage::read(reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; if vdm_message.major_version != VDM_MESSAGE_MAJOR_VERSION { error!( @@ -591,9 +580,8 @@ pub async fn send_and_receive_sdm_migration_attest_info( } let vdm_attest_info_src_hash = - digest_sha384(&send_buffer[..used]).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; - let vdm_attest_info_dst_hash = - digest_sha384(&receive_buffer[..receive_used]).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; + digest_sha384(&send_buffer[..send_used]).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; + let vdm_attest_info_dst_hash = digest_sha384(response).map_err(|_| SPDM_STATUS_CRYPTO_ERROR)?; let mut transcript_before_finish = ManagedVdmBuffer::default(); transcript_before_finish .append_message(vdm_attest_info_src_hash.as_slice()) @@ -622,7 +610,7 @@ async fn send_and_receive_sdm_exchange_migration_info( let mut exchange_information = exchange_info(mig_info, false)?; - let mut payload = [0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE]; + let mut payload = vec![0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE]; let mut writer = Writer::init(&mut payload); let mut cnt = 0; @@ -668,11 +656,6 @@ async fn send_and_receive_sdm_exchange_migration_info( .encode(&mut writer) .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; - let vdm_payload = VendorDefinedReqPayloadStruct { - req_length: cnt as u32, - vendor_defined_req_payload: payload, - }; - spdm_requester.common.reset_buffer_via_request_code( SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, None, @@ -680,34 +663,42 @@ async fn send_and_receive_sdm_exchange_migration_info( let mut send_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; let mut writer = Writer::init(&mut send_buffer); - let request = SpdmMessage { - header: SpdmMessageHeader { - version: spdm_requester.common.negotiate_info.spdm_version_sel, - request_response_code: SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, - }, - payload: SpdmMessagePayload::SpdmVendorDefinedRequest(SpdmVendorDefinedRequestPayload { - standard_id: RegistryOrStandardsBodyID::IANA, - vendor_id, - req_payload: vdm_payload, - }), - }; - let used = request.spdm_encode(&mut spdm_requester.common, &mut writer)?; - spdm_requester - .send_message(session_id, &send_buffer[..used], false) - .await?; + let request_header = SpdmMessageHeader { + version: spdm_requester.common.negotiate_info.spdm_version_sel, + request_response_code: SpdmRequestResponseCode::SpdmRequestVendorDefinedRequest, + }; + let request_payload = SpdmVdmRequestPayload { + standard_id: RegistryOrStandardsBodyID::IANA, + vendor_id, + req_length: cnt as u32, + req_payload: payload, + }; + let mut send_used = 0; + send_used += request_header + .encode(&mut writer) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + send_used += request_payload.spdm_encode(&mut spdm_requester.common, &mut writer)?; let mut receive_buffer = [0u8; config::MAX_SPDM_MSG_SIZE]; - let receive_used = spdm_requester - .receive_message(session_id, &mut receive_buffer, false) + let response = spdm_requester + .send_spdm_vendor_defined_request_ex( + session_id, + &send_buffer[..send_used], + &mut receive_buffer, + ) .await?; - let vdm_payload = spdm_requester - .handle_spdm_vendor_defined_respond(session_id, &receive_buffer[..receive_used])?; + // Format checks + let mut reader = Reader::init(response); + let _response_header = + SpdmMessageHeader::read(&mut reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; + let response_payload = + SpdmVdmResponsePayload::spdm_read(&mut spdm_requester.common, &mut reader) + .ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; - let reader = &mut Reader::init( - &vdm_payload.vendor_defined_rsp_payload[..vdm_payload.rsp_length as usize], - ); + let reader = + &mut Reader::init(&response_payload.rsp_payload[..response_payload.rsp_length as usize]); let vdm_message = VdmMessage::read(reader).ok_or(SPDM_STATUS_INVALID_MSG_SIZE)?; if vdm_message.major_version != VDM_MESSAGE_MAJOR_VERSION { error!( diff --git a/src/migtd/src/spdm/spdm_vdm.rs b/src/migtd/src/spdm/spdm_vdm.rs index 191912b0..cb86b0e7 100644 --- a/src/migtd/src/spdm/spdm_vdm.rs +++ b/src/migtd/src/spdm/spdm_vdm.rs @@ -2,12 +2,15 @@ // // SPDX-License-Identifier: BSD-2-Clause-Patent +use alloc::vec::Vec; use codec::{enum_builder, Codec, Reader, Writer}; use crypto::hash::digest_sha384; +use spdmlib::common; use spdmlib::{ common::{ManagedVdmBuffer, SpdmCodec}, error::*, message::*, + protocol::{SpdmRequestCapabilityFlags, SpdmResponseCapabilityFlags, SpdmVersion}, responder::ResponderContext, }; @@ -137,6 +140,229 @@ impl Codec for VdmMessageElement { } } +// Define the VDM request and response payloads rather than reuse Spdm lib structures to avoid using large slices in stack. +#[derive(Debug, Clone)] +pub struct SpdmVdmRequestPayload { + pub standard_id: RegistryOrStandardsBodyID, + pub vendor_id: VendorIDStruct, + pub req_length: u32, + pub req_payload: Vec, +} + +impl SpdmCodec for SpdmVdmRequestPayload { + fn spdm_encode( + &self, + context: &mut common::SpdmContext, + bytes: &mut Writer, + ) -> Result { + let large_payload = context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion14 + && context + .negotiate_info + .rsp_capabilities_sel + .contains(SpdmResponseCapabilityFlags::LARGE_RESP_CAP) + && context + .negotiate_info + .req_capabilities_sel + .contains(SpdmRequestCapabilityFlags::LARGE_RESP_CAP); + let mut cnt = 0usize; + let param1 = if large_payload { + SpdmVdmFlags::USE_LARGE_PAYLOAD + } else { + SpdmVdmFlags::default() + }; + cnt += param1.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param1 + cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 + cnt += self + .standard_id + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; //Standard ID + cnt += self + .vendor_id + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + if large_payload { + if self.req_length as usize > self.req_payload.len() { + return Err(SPDM_STATUS_INVALID_MSG_FIELD); + } + cnt += 0u16.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // req_length + cnt += self + .req_length + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + for d in self.req_payload.iter().take(self.req_length as usize) { + cnt += d.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + } + } else { + if self.req_length > u16::MAX as u32 + || self.req_payload.len() < self.req_length as usize + { + return Err(SPDM_STATUS_INVALID_MSG_FIELD); + } + cnt += (self.req_length as u16) + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + for d in self.req_payload.iter().take(self.req_length as usize) { + cnt += d.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + } + } + Ok(cnt) + } + + fn spdm_read( + context: &mut common::SpdmContext, + r: &mut Reader, + ) -> Option { + let param1 = SpdmVdmFlags::read(r)?; // param1 + u8::read(r)?; // param2 + let large_payload = param1.contains(SpdmVdmFlags::USE_LARGE_PAYLOAD); + if large_payload + && !(context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion14 + && context + .negotiate_info + .rsp_capabilities_sel + .contains(SpdmResponseCapabilityFlags::LARGE_RESP_CAP) + && context + .negotiate_info + .req_capabilities_sel + .contains(SpdmRequestCapabilityFlags::LARGE_RESP_CAP)) + { + return None; + } + let standard_id = RegistryOrStandardsBodyID::read(r)?; // Standard ID + let vendor_id = VendorIDStruct::read(r)?; + let req_length = if large_payload { + let _ = u16::read(r)?; // rsp_length (reserved) + u32::read(r)? + } else { + let len = u16::read(r)?; // rsp_length + len as u32 + }; + let mut req_payload = Vec::with_capacity(req_length as usize); + for _ in 0..req_length { + let d = u8::read(r)?; + req_payload.push(d); + } + + Some(SpdmVdmRequestPayload { + standard_id, + vendor_id, + req_length, + req_payload, + }) + } +} + +#[derive(Debug, Clone)] +pub struct SpdmVdmResponsePayload { + pub standard_id: RegistryOrStandardsBodyID, + pub vendor_id: VendorIDStruct, + pub rsp_length: u32, + pub rsp_payload: Vec, +} + +impl SpdmCodec for SpdmVdmResponsePayload { + fn spdm_encode( + &self, + context: &mut common::SpdmContext, + bytes: &mut Writer, + ) -> Result { + let large_payload = context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion14 + && context + .negotiate_info + .rsp_capabilities_sel + .contains(SpdmResponseCapabilityFlags::LARGE_RESP_CAP) + && context + .negotiate_info + .req_capabilities_sel + .contains(SpdmRequestCapabilityFlags::LARGE_RESP_CAP); + let mut cnt = 0usize; + let param1 = if large_payload { + SpdmVdmFlags::USE_LARGE_PAYLOAD + } else { + SpdmVdmFlags::default() + }; + cnt += param1.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param1 + cnt += 0u8.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // param2 + cnt += self + .standard_id + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; //Standard ID + cnt += self + .vendor_id + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + if large_payload { + if self.rsp_length as usize > self.rsp_payload.len() { + return Err(SPDM_STATUS_INVALID_MSG_FIELD); + } + cnt += 0u16.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; // rsp_length + cnt += self + .rsp_length + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + for d in self.rsp_payload.iter().take(self.rsp_length as usize) { + cnt += d.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + } + } else { + if self.rsp_length > u16::MAX as u32 + || self.rsp_payload.len() < self.rsp_length as usize + { + return Err(SPDM_STATUS_INVALID_MSG_FIELD); + } + cnt += (self.rsp_length as u16) + .encode(bytes) + .map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + for d in self.rsp_payload.iter().take(self.rsp_length as usize) { + cnt += d.encode(bytes).map_err(|_| SPDM_STATUS_BUFFER_FULL)?; + } + } + Ok(cnt) + } + + fn spdm_read( + context: &mut common::SpdmContext, + r: &mut Reader, + ) -> Option { + let param1 = SpdmVdmFlags::read(r)?; // param1 + u8::read(r)?; // param2 + let large_payload = param1.contains(SpdmVdmFlags::USE_LARGE_PAYLOAD); + if large_payload + && !(context.negotiate_info.spdm_version_sel >= SpdmVersion::SpdmVersion14 + && context + .negotiate_info + .rsp_capabilities_sel + .contains(SpdmResponseCapabilityFlags::LARGE_RESP_CAP) + && context + .negotiate_info + .req_capabilities_sel + .contains(SpdmRequestCapabilityFlags::LARGE_RESP_CAP)) + { + return None; + } + let standard_id = RegistryOrStandardsBodyID::read(r)?; // Standard ID + let vendor_id = VendorIDStruct::read(r)?; + let rsp_length = if large_payload { + let _ = u16::read(r)?; // rsp_length (reserved) + u32::read(r)? + } else { + let len = u16::read(r)?; // rsp_length + len as u32 + }; + let mut rsp_payload = Vec::with_capacity(rsp_length as usize); + for _ in 0..rsp_length { + let d = u8::read(r)?; + rsp_payload.push(d); + } + + Some(SpdmVdmResponsePayload { + standard_id, + vendor_id, + rsp_length, + rsp_payload, + }) + } +} + pub fn migtd_vdm_msg_rsp_dispatcher_ex<'a>( responder_context: &mut ResponderContext, session_id: Option, @@ -163,18 +389,16 @@ pub fn migtd_vdm_msg_rsp_dispatcher_ex<'a>( return (Err(SPDM_STATUS_INVALID_MSG_FIELD), Some(&rsp_bytes[..used])); } - let vendor_defined_request_payload = - SpdmVendorDefinedRequestPayload::spdm_read(&mut responder_context.common, &mut reader); - if vendor_defined_request_payload.is_none() { + let req_payload = SpdmVdmRequestPayload::spdm_read(&mut responder_context.common, &mut reader); + if req_payload.is_none() { responder_context.write_spdm_error(SpdmErrorCode::SpdmErrorInvalidRequest, 0, &mut writer); let used = writer.used(); return (Err(SPDM_STATUS_INVALID_MSG_FIELD), Some(&rsp_bytes[..used])); } - let vendor_defined_request_payload = vendor_defined_request_payload.unwrap(); - let standard_id = vendor_defined_request_payload.standard_id; - let vendor_id = vendor_defined_request_payload.vendor_id; - let req_payload = vendor_defined_request_payload.req_payload; + let req_payload = req_payload.unwrap(); + let standard_id = req_payload.standard_id; + let vendor_id = req_payload.vendor_id; if standard_id != RegistryOrStandardsBodyID::IANA || vendor_id.len != VDM_MESSAGE_VENDOR_ID_LEN as u8 @@ -185,8 +409,7 @@ pub fn migtd_vdm_msg_rsp_dispatcher_ex<'a>( return (Err(SPDM_STATUS_INVALID_MSG_FIELD), Some(&rsp_bytes[..used])); } - let mut reader = - Reader::init(&req_payload.vendor_defined_req_payload[0..req_payload.req_length as usize]); + let mut reader = Reader::init(&req_payload.req_payload[0..req_payload.req_length as usize]); let vdm_request = if let Some(vdm_request) = VdmMessage::read(&mut reader) { vdm_request } else { @@ -195,67 +418,51 @@ pub fn migtd_vdm_msg_rsp_dispatcher_ex<'a>( return (Err(SPDM_STATUS_INVALID_MSG_SIZE), Some(&rsp_bytes[..used])); }; - let mut response = SpdmMessage { - header: SpdmMessageHeader { - version: responder_context.common.negotiate_info.spdm_version_sel, - request_response_code: SpdmRequestResponseCode::SpdmResponseVendorDefinedResponse, - }, - payload: SpdmMessagePayload::SpdmVendorDefinedResponse(SpdmVendorDefinedResponsePayload { - standard_id: RegistryOrStandardsBodyID::IANA, - vendor_id: VendorIDStruct { - len: VDM_MESSAGE_VENDOR_ID_LEN as u8, - vendor_id: [0u8; MAX_SPDM_VENDOR_DEFINED_VENDOR_ID_LEN], - }, - rsp_payload: VendorDefinedRspPayloadStruct { - rsp_length: 0, - vendor_defined_rsp_payload: [0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE], - }, - }), + let response_header = SpdmMessageHeader { + version: responder_context.common.negotiate_info.spdm_version_sel, + request_response_code: SpdmRequestResponseCode::SpdmResponseVendorDefinedResponse, }; - let vdm_payload = match &mut response.payload { - SpdmMessagePayload::SpdmVendorDefinedResponse(vdm_payload) => vdm_payload, - _ => { - responder_context.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, &mut writer); - let used = writer.used(); - return ( - Err(SPDM_STATUS_INVALID_STATE_LOCAL), - Some(&rsp_bytes[..used]), - ); - } + let mut vdm_rsp_payload = SpdmVdmResponsePayload { + standard_id: RegistryOrStandardsBodyID::IANA, + vendor_id: VendorIDStruct { + len: VDM_MESSAGE_VENDOR_ID_LEN as u8, + vendor_id: [0u8; MAX_SPDM_VENDOR_DEFINED_VENDOR_ID_LEN], + }, + rsp_length: 0, + rsp_payload: vec![0u8; MAX_SPDM_VENDOR_DEFINED_PAYLOAD_SIZE], }; // Patch the vendor id field - vdm_payload.vendor_id.vendor_id[..VDM_MESSAGE_VENDOR_ID_LEN] + vdm_rsp_payload.vendor_id.vendor_id[..VDM_MESSAGE_VENDOR_ID_LEN] .copy_from_slice(&VDM_MESSAGE_VENDOR_ID); //Patch the response payload - let rsp_payload = &mut vdm_payload.rsp_payload; let vdm_payload_size = match vdm_request.op_code { VdmMessageOpCode::ExchangePubKeyReq => handle_exchange_pub_key_req( responder_context, &vdm_request, &mut reader, - &mut rsp_payload.vendor_defined_rsp_payload, + &mut vdm_rsp_payload.rsp_payload, ), VdmMessageOpCode::ExchangeMigrationAttestInfoReq => handle_exchange_mig_attest_info_req( responder_context, session_id, &vdm_request, &mut reader, - &mut rsp_payload.vendor_defined_rsp_payload, + &mut vdm_rsp_payload.rsp_payload, ), VdmMessageOpCode::ExchangeMigrationInfoReq => handle_exchange_mig_info_req( responder_context, session_id, &vdm_request, &mut reader, - &mut rsp_payload.vendor_defined_rsp_payload, + &mut vdm_rsp_payload.rsp_payload, ), _ => Err(SPDM_STATUS_INVALID_MSG_FIELD), }; if let Ok(vdm_payload_size) = vdm_payload_size { - rsp_payload.rsp_length = vdm_payload_size as u32; + vdm_rsp_payload.rsp_length = vdm_payload_size as u32; } else { responder_context.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, &mut writer); let used = writer.used(); @@ -265,7 +472,17 @@ pub fn migtd_vdm_msg_rsp_dispatcher_ex<'a>( ); }; - let res = response.spdm_encode(&mut responder_context.common, &mut writer); + let res = response_header.encode(&mut writer); + if res.is_err() { + responder_context.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, &mut writer); + let used = writer.used(); + return ( + Err(SPDM_STATUS_INVALID_STATE_LOCAL), + Some(&rsp_bytes[..used]), + ); + } + + let res = vdm_rsp_payload.spdm_encode(&mut responder_context.common, &mut writer); if res.is_err() { responder_context.write_spdm_error(SpdmErrorCode::SpdmErrorUnspecified, 0, &mut writer); let used = writer.used();