diff --git a/Cargo.toml b/Cargo.toml index a9675ff..9e6527a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,11 @@ categories = ["embedded", "no-std"] [dependencies] mctp-estack = { git = "https://github.com/OpenPRoT/mctp-rs.git", branch = "sync-features", default-features = false, features = ["log"] } mctp = { git = "https://github.com/OpenPRoT/mctp-rs.git", branch = "sync-features", default-features = false } +zerocopy = {version = "0.8.17", features = ["derive"]} [dev-dependencies] +mctp = { git = "https://github.com/OpenPRoT/mctp-rs.git", branch = "sync-features" } +standalone = { path = "standalone" } [package.metadata.docs.rs] all-features = true diff --git a/examples/echo/main.rs b/examples/echo/main.rs new file mode 100644 index 0000000..e1b39e5 --- /dev/null +++ b/examples/echo/main.rs @@ -0,0 +1,65 @@ +// Copyright 2025 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Example that listens for a request and echoes the payload in the response. +//! +//! Uses the standalone std implementation for the Stack and attaches to a specified serial port. +//! (Use a tool like _socat_ to attach to the linux MCTP stack through PTYs) +//! +//! Errors after the specified timeout. + +const MSG_TYPE: MsgType = MsgType(1); +const OWN_EID: Eid = Eid(8); +const TIMEOUT_SECS: u64 = 10; +const TTY_PATH: &str = "pts1"; + +use std::{fs::File, thread::spawn, time::Duration}; + +use mctp::{Eid, Listener, MsgType, RespChannel}; +use standalone::{ + Stack, + serial_sender::IoSerialSender, + util::{inbound_loop, update_loop}, +}; + +fn main() { + let serial = File::options() + .write(true) + .read(true) + .open(TTY_PATH) + .unwrap(); + + let serial_sender = IoSerialSender::new(serial.try_clone().unwrap()); + + let mut stack = Stack::new(serial_sender); + + stack.set_eid(OWN_EID).unwrap(); + + let update_stack = stack.clone(); + spawn(move || update_loop(update_stack)); + + let driver_stack = stack.clone(); + spawn(move || inbound_loop(driver_stack, serial)); + + let mut listener = stack + .listener(MSG_TYPE, Some(Duration::from_secs(TIMEOUT_SECS))) + .unwrap(); + + let mut buf = [0; 256]; + let (_, _, msg, mut rsp) = listener.recv(&mut buf).unwrap(); + + println!("Got message: {:#x?}", msg); + + rsp.send(msg).unwrap(); +} diff --git a/examples/mctp-control-endpoint/main.rs b/examples/mctp-control-endpoint/main.rs new file mode 100644 index 0000000..8ca8a65 --- /dev/null +++ b/examples/mctp-control-endpoint/main.rs @@ -0,0 +1,179 @@ +// Copyright 2025 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Example that listens for a Set Endpoint ID MCTP Control message, sets the ID and responds to the request. +//! +//! Uses the standalone std implementation for the Stack and attaches to a specified serial port. +//! (Use a tool like _socat_ to attach to the linux MCTP stack through PTYs) +//! +//! Errors after the specified timeout. + +const TIMEOUT_SECS: u64 = 10; +const TTY_PATH: &str = "pts1"; + +use std::{fs::File, thread::spawn, time::Duration}; + +use mctp::{Listener, RespChannel}; +use standalone::{ + Response, Stack, + serial_sender::IoSerialSender, + util::{inbound_loop, update_loop}, +}; + +use mctp_lib::mctp_control::{ + CompletionCode, MctpControlHeader, SetEndpointIDOperation, SetEndpointIdRequest, + SetEndpointIdResponse, codec::MctpCodec, +}; +use mctp_lib::{Sender, mctp_control::MctpControlMessage}; + +fn main() { + let serial = File::options() + .write(true) + .read(true) + .open(TTY_PATH) + .unwrap(); + + let serial_sender = IoSerialSender::new(serial.try_clone().unwrap()); + + let mut stack = Stack::new(serial_sender); + + let update_stack = stack.clone(); + spawn(move || update_loop(update_stack)); + + let driver_stack = stack.clone(); + spawn(move || inbound_loop(driver_stack, serial)); + + // MCTP Control Endpoint flow start + + let mut listener = stack + .listener( + mctp::MCTP_TYPE_CONTROL, + Some(Duration::from_secs(TIMEOUT_SECS)), + ) + .unwrap(); + + let mut buf = [0; 256]; + let (_, _, msg, rsp) = listener.recv(&mut buf).unwrap(); + + let ctrl_msg = MctpControlMessage::decode(msg).unwrap(); + + if !ctrl_msg.control_header.request { + panic!("Got a MCTP Control response while expecting a request"); + } + + match ctrl_msg.control_header.command_code { + mctp_lib::mctp_control::CommandCode::SetEndpointID => { + handle_set_endpoint_id(&ctrl_msg, rsp, &mut stack) + } + _ => unimplemented!(), + } +} + +/// Handles a Set Endpoint ID command and responds to the request +/// +/// The message buffer is expected to contain the Set Endpoint ID Message (Spec v1.3.3 Table 14). +fn handle_set_endpoint_id( + msg: &MctpControlMessage, + mut rsp: Response, + stack: &mut Stack, +) { + let Ok(set_eid_msg) = SetEndpointIdRequest::decode(msg.message_body) else { + let mut rsp_buf = [0; 32]; + let rsp_msg = set_eid_error( + msg.control_header.clone(), + CompletionCode::ErrorInvalidData, + &mut rsp_buf, + ); + rsp.send(rsp_msg).unwrap(); + return; + }; + + let eid = match set_eid_msg.0 { + SetEndpointIDOperation::SetEid(eid) => eid, // We always accept for simplicity here + SetEndpointIDOperation::ForceEid(eid) => eid, + SetEndpointIDOperation::ResetEid => { + let mut rsp_buf = [0; 32]; + let rsp_msg = set_eid_error( + msg.control_header.clone(), + CompletionCode::ErrorInvalidData, + &mut rsp_buf, + ); + rsp.send(rsp_msg).unwrap(); + return; + } + SetEndpointIDOperation::SetDiscoveredFlag => { + let mut rsp_buf = [0; 32]; + let rsp_msg = set_eid_error( + msg.control_header.clone(), + CompletionCode::ErrorInvalidData, + &mut rsp_buf, + ); + rsp.send(rsp_msg).unwrap(); + return; + } + }; + + if stack.set_eid(eid).is_err() { + let mut rsp_buf = [0; 32]; + let rsp_msg = set_eid_error( + msg.control_header.clone(), + CompletionCode::Error, + &mut rsp_buf, + ); + rsp.send(rsp_msg).unwrap(); + return; + } + println!("Assigned new EID {eid}"); + + let mut rsp_buf = [0; 32]; + let mut set_eid_resp = [0; 4]; + assert!( + SetEndpointIdResponse::new( + CompletionCode::Error, + mctp_lib::mctp_control::EidAssignmentStatus::Accepted, + mctp_lib::mctp_control::EidAllocationStatus::NoEidPoolUsed, + eid, + 0, + ) + .encode(&mut set_eid_resp) + .is_ok_and(|n| n == 4) + ); + let mut header = msg.control_header.clone(); + header.request = false; + let size = MctpControlMessage::new(header, &set_eid_resp) + .encode(&mut rsp_buf) + .unwrap(); + rsp.send(&rsp_buf[..size]).unwrap(); +} + +/// Formats a Set EID error response with given header and completion code into buf +fn set_eid_error(mut header: MctpControlHeader, code: CompletionCode, buf: &mut [u8]) -> &[u8] { + let mut err_resp = [0; 4]; + assert!( + SetEndpointIdResponse::new( + code, + mctp_lib::mctp_control::EidAssignmentStatus::Rejected, + mctp_lib::mctp_control::EidAllocationStatus::NoEidPoolUsed, + mctp::Eid(0), + 0, + ) + .encode(&mut err_resp) + .is_ok_and(|n| n == 4) + ); + header.request = false; + let size = MctpControlMessage::new(header, &err_resp) + .encode(buf) + .unwrap(); + &buf[..size] +} diff --git a/src/lib.rs b/src/lib.rs index f10bcfa..39e281e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ //! parts of it. #![cfg_attr(not(test), no_std)] #![deny(unsafe_code)] -#![deny(missing_docs)] +// #![deny(missing_docs)] #![deny(clippy::missing_panics_doc)] #![deny(clippy::panic)] #![deny(clippy::panicking_overflow_checks)] @@ -27,7 +27,11 @@ use mctp::{Eid, Error, MsgIC, MsgType, Result, Tag}; use mctp_estack::fragment::Fragmenter; -pub use mctp_estack::*; +pub use mctp_estack::{ + AppCookie, MctpMessage, Stack, TIMEOUT_INTERVAL, Vec, config, fragment, i2c, serial, usb, +}; + +pub mod mctp_control; #[derive(Debug)] struct ReqHandle { @@ -94,15 +98,20 @@ impl /// Provide an incoming packet to the router. /// /// This expects a single MCTP packet, without a transport binding header. - pub fn inbound(&mut self, pkt: &[u8]) -> Result<()> { + /// + /// Returns `Ok(Some(AppCookie))` for a associated listener or request, + /// or `Ok(None)` if the message was discarded. + pub fn inbound(&mut self, pkt: &[u8]) -> Result> { let own_eid = self.stack.eid(); let Some(mut msg) = self.stack.receive(pkt)? else { - return Ok(()); + return Ok(None); }; - if msg.dest != own_eid { - // Drop messages if eid does not match (for now) - return Ok(()); + if msg.dest != own_eid && msg.dest != Eid(0) { + // Drop messages if eid does not match (for now). + // EID 0 messages are used for physical addressing + // and will thus be processed. + return Ok(None); } match msg.tag { @@ -113,7 +122,7 @@ impl .is_some_and(|i| self.requests.get(i).is_some_and(|r| r.is_some())) { msg.retain(); - return Ok(()); + return Ok(Some(cookie)); } // In this case an unowned message not associated with a request was received. // This might happen if this endpoint was intended to route the packet to a different @@ -124,16 +133,17 @@ impl // check for matching listeners and retain with cookie for i in 0..self.listeners.len() { if self.listeners.get(i).ok_or(Error::InternalError)? == &Some(msg.typ) { - msg.set_cookie(Some(Self::listener_cookie_from_index(i))); + let cookie = Some(Self::listener_cookie_from_index(i)); + msg.set_cookie(cookie); msg.retain(); - return Ok(()); + return Ok(cookie); } } } } // Return Ok(()) even if a message has been discarded - Ok(()) + Ok(None) } /// Allocate a new request "_Handle_" @@ -221,12 +231,14 @@ impl Some(cookie), )?; - self.sender.send_vectored(frag, bufs) + self.sender.send_vectored(eid, frag, bufs) } /// Receive a message associated with a [`AppCookie`] /// /// Returns `None` when no message is available for the listener/request. + /// + /// The message can be retained and received at a later point again (see [MctpMessage::retain()]). pub fn recv(&mut self, cookie: AppCookie) -> Option> { self.stack.get_deferred_bycookie(&[cookie]) } @@ -326,7 +338,8 @@ impl /// Implemented by a transport binding for sending packets. pub trait Sender { /// Send a packet fragmented by `fragmenter` with the payload `payload` - fn send_vectored(&mut self, fragmenter: Fragmenter, payload: &[&[u8]]) -> Result; + fn send_vectored(&mut self, eid: Eid, fragmenter: Fragmenter, payload: &[&[u8]]) + -> Result; /// Get the MTU of a MCTP packet fragment (without transport headers) fn get_mtu(&self) -> usize; } @@ -345,6 +358,7 @@ mod test { impl Sender for DoNothingSender { fn send_vectored( &mut self, + _eid: Eid, fragmenter: mctp_estack::fragment::Fragmenter, payload: &[&[u8]], ) -> core::result::Result { @@ -364,6 +378,7 @@ mod test { impl Sender for BufferSender<'_, MTU> { fn send_vectored( &mut self, + _eid: Eid, mut fragmenter: mctp_estack::fragment::Fragmenter, payload: &[&[u8]], ) -> core::result::Result { diff --git a/src/mctp_control/codec.rs b/src/mctp_control/codec.rs new file mode 100644 index 0000000..eb558ed --- /dev/null +++ b/src/mctp_control/codec.rs @@ -0,0 +1,71 @@ +// Licensed under the Apache-2.0 license + +use zerocopy::{FromBytes, Immutable, IntoBytes}; + +//As of DSP0235 1.3.3 Line 1823: "11.6 MCTP control message transmission unit size" +const MCTP_CONTROL_MTU: usize = 64; + +#[allow(dead_code)] +#[derive(Debug, PartialEq)] +pub enum MctpCodecError { + BufferTooShort, + Unsupported, + InvalidData, + InternalError, + UnsupportedBufferSize, +} + +/// A trait for encoding and decoding MCTP (Management Component Transport Protocol) messages. +/// +/// This trait provides methods for encoding an MCTP message into a byte buffer +/// and decoding an MCTP message from a byte buffer. Implementers of this trait +/// must also implement the `Debug` trait and be `Sized`. +#[allow(dead_code)] +pub trait MctpCodec<'a>: core::fmt::Debug + Sized { + /// Encodes the MCTP message into the provided byte buffer. + /// + /// # Arguments + /// + /// * `buffer` - A mutable reference to a byte slice where the encoded message will be stored. + /// + /// # Returns + /// + /// A `Result` containing the size of the encoded message on success, or an `MctpCodecError` on failure. + fn encode(&self, buffer: &mut [u8]) -> Result; + + /// Decodes an MCTP message from the provided byte buffer. + /// + /// # Arguments + /// + /// * `buffer` - A reference to a byte slice containing the encoded message. + /// + /// # Returns + /// + /// A `Result` containing the decoded message on success, or an `MctpCodecError` on failure. + fn decode(buffer: &'a [u8]) -> Result; + + /// Maximum supported size of MCTP message in bytes. + /// + /// Defaults to `core::mem::size_of::()` for the implementing type. + const MCTP_CODEC_MIN_SIZE: usize = core::mem::size_of::(); +} + +// Default implementation of MctpCodec for types that can leverage zerocopy. +// TODO: can we generalize this to use sub-struct encodes when possible? + +impl MctpCodec<'_> for T +where + T: core::fmt::Debug + Sized + FromBytes + IntoBytes + Immutable, +{ + fn encode(&self, buffer: &mut [u8]) -> Result { + self.write_to_prefix(buffer) + .map_err(|_| MctpCodecError::BufferTooShort) + .map(|_| Self::MCTP_CODEC_MIN_SIZE) + } + + fn decode(buffer: &[u8]) -> Result { + Ok(Self::read_from_prefix(buffer) + .map_err(|_| MctpCodecError::BufferTooShort)? + .0) + } +} diff --git a/src/mctp_control/mod.rs b/src/mctp_control/mod.rs new file mode 100644 index 0000000..afa12bb --- /dev/null +++ b/src/mctp_control/mod.rs @@ -0,0 +1,1689 @@ +// Copyright 2025 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(unused)] + +use mctp::{Eid, Error}; + +pub mod codec; +use crate::mctp_control::codec::{MctpCodec, MctpCodecError}; + +/// A `Result` with a MCTP control completion code as error. +pub type ControlResult = core::result::Result; + +/// MCTP control message completion code. +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] + +pub enum CompletionCode { + Success, + Error, + ErrorInvalidData, + ErrorInvalidLength, + ErrorNotReady, + ErrorUnsupportedCmd, + /// 0x80-0xff + /// Command-specific completion code with a custom value. + /// + /// This variant represents completion codes that are specific to individual + /// MCTP control commands and carries the raw completion code value. + CommandSpecific(u8), + Other(u8), +} + +impl From for CompletionCode { + fn from(value: u8) -> Self { + use CompletionCode::*; + match value { + 0x00 => Success, + 0x01 => Error, + 0x02 => ErrorInvalidData, + 0x03 => ErrorInvalidLength, + 0x04 => ErrorNotReady, + 0x05 => ErrorUnsupportedCmd, + 0x80..=0xff => CommandSpecific(value), + _ => Other(value), + } + } +} + +impl From for u8 { + fn from(cc: CompletionCode) -> Self { + use CompletionCode::*; + match cc { + Success => 0x00, + Error => 0x01, + ErrorInvalidData => 0x02, + ErrorInvalidLength => 0x03, + ErrorNotReady => 0x04, + ErrorUnsupportedCmd => 0x05, + CommandSpecific(v) | Other(v) => v, + } + } +} + +/// MCTP control command code. +#[allow(missing_docs)] +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub enum CommandCode { + SetEndpointID, + GetEndpointID, + GetEndpointUUID, + GetMCTPVersionSupport, + GetMessageTypeSupport, + GetVendorDefinedMessageSupport, + ResolveEndpointID, + AllocateEndpointIDs, + RoutingInformationUpdate, + GetRoutingTableEntries, + PrepareforEndpointDiscovery, + DiscoveryNotify, + QueryHop, + ResolveUUID, + QueryRateRimit, + RequestTXRateLimit, + UpdateRateLimit, + QuerySupportedInterfaces, + TransportSpecific(u8), + Unknown(u8), +} + +impl From for CommandCode { + fn from(value: u8) -> Self { + use CommandCode::*; + match value { + 0x01 => SetEndpointID, + 0x02 => GetEndpointID, + 0x03 => GetEndpointUUID, + 0x04 => GetMCTPVersionSupport, + 0x05 => GetMessageTypeSupport, + 0x06 => GetVendorDefinedMessageSupport, + 0x07 => ResolveEndpointID, + 0x08 => AllocateEndpointIDs, + 0x09 => RoutingInformationUpdate, + 0x0A => GetRoutingTableEntries, + 0x0B => PrepareforEndpointDiscovery, + 0x0D => DiscoveryNotify, + 0x0F => QueryHop, + 0x10 => ResolveUUID, + 0x11 => QueryRateRimit, + 0x12 => RequestTXRateLimit, + 0x13 => UpdateRateLimit, + 0x14 => QuerySupportedInterfaces, + 0xf0..=0xff => TransportSpecific(value), + _ => Unknown(value), + } + } +} + +impl From for u8 { + fn from(cc: CommandCode) -> Self { + use CommandCode::*; + match cc { + SetEndpointID => 0x01, + GetEndpointID => 0x02, + GetEndpointUUID => 0x03, + GetMCTPVersionSupport => 0x04, + GetMessageTypeSupport => 0x05, + GetVendorDefinedMessageSupport => 0x06, + ResolveEndpointID => 0x07, + AllocateEndpointIDs => 0x08, + RoutingInformationUpdate => 0x09, + GetRoutingTableEntries => 0x0A, + PrepareforEndpointDiscovery => 0x0B, + DiscoveryNotify => 0x0D, + QueryHop => 0x0F, + ResolveUUID => 0x10, + QueryRateRimit => 0x11, + RequestTXRateLimit => 0x12, + UpdateRateLimit => 0x13, + QuerySupportedInterfaces => 0x14, + TransportSpecific(v) | Unknown(v) => v, + } + } +} + +#[derive(PartialEq, Eq, Clone, Hash, Debug)] +pub struct MctpControlHeader { + pub request: bool, + pub datagram: bool, + pub instance_id: u8, + pub command_code: CommandCode, +} + +impl MctpControlHeader { + pub fn new(request: bool, datagram: bool, instance_id: u8, command_code: CommandCode) -> Self { + Self { + request, + datagram, + instance_id, + command_code, + } + } +} + +impl MctpCodec<'_> for MctpControlHeader { + fn encode(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < MctpControlHeader::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + + *(buffer.get_mut(0).ok_or(MctpCodecError::InternalError)?) = + (((self.request as u8) << 7) | ((self.datagram as u8) << 6) | self.instance_id); + *(buffer.get_mut(1).ok_or(MctpCodecError::InternalError)?) = self.command_code.into(); + Ok(MctpControlHeader::MCTP_CODEC_MIN_SIZE) + } + + fn decode(buffer: &[u8]) -> Result { + if buffer.len() < MctpControlHeader::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::InvalidData); + } + + let request: bool = (buffer.first().ok_or(MctpCodecError::InvalidData)? & 0b1000_0000) != 0; + let datagram: bool = + (buffer.first().ok_or(MctpCodecError::InvalidData)? & 0b0100_0000) != 0; + let instance_id: u8 = (buffer.first().ok_or(MctpCodecError::InvalidData)? & 0b0011_1111); + let command_code: CommandCode = (*buffer.get(1).ok_or(MctpCodecError::InvalidData)?).into(); + + Ok(Self { + request, + datagram, + instance_id, + command_code, + }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 2; +} + +#[derive(PartialEq, Eq, Clone, Hash, Debug)] +pub struct MctpControlMessage<'a> { + pub control_header: MctpControlHeader, + pub message_body: &'a [u8], +} + +impl<'a> MctpControlMessage<'a> { + pub fn new(control_header: MctpControlHeader, message_body: &'a [u8]) -> Self { + Self { + control_header, + message_body, + } + } +} + +impl<'a> MctpCodec<'a> for MctpControlMessage<'a> { + fn encode(&self, buffer: &mut [u8]) -> Result { + if ((self.message_body.len() + MctpControlHeader::MCTP_CODEC_MIN_SIZE) > buffer.len()) { + return Err(MctpCodecError::BufferTooShort); + } + + let header_size = MctpControlHeader::MCTP_CODEC_MIN_SIZE; + let header_buffer = buffer + .get_mut(..header_size) + .ok_or(MctpCodecError::InvalidData)?; + self.control_header.encode(header_buffer)?; + + buffer + .get_mut(header_size..header_size + self.message_body.len()) + .ok_or(MctpCodecError::InvalidData)? + .copy_from_slice(self.message_body); + + Ok(header_size + self.message_body.len()) + } + + fn decode(buffer: &'a [u8]) -> Result { + if (buffer.len() < MctpControlMessage::MCTP_CODEC_MIN_SIZE) { + return Err(MctpCodecError::InvalidData); + } + + let control_header: MctpControlHeader = + MctpControlHeader::decode(buffer.get(..2).ok_or(MctpCodecError::InternalError)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let message_body = &buffer.get(2..).ok_or(MctpCodecError::InternalError)?; + + Ok(Self { + control_header, + message_body, + }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 2; +} + +#[derive(PartialEq, Eq, Clone, Hash, Debug)] +pub enum SetEndpointIDOperation { + SetEid(Eid), + ForceEid(Eid), + ResetEid, + SetDiscoveredFlag, +} + +impl TryFrom<(u8, Eid)> for SetEndpointIDOperation { + type Error = CompletionCode; + + fn try_from((value, eid): (u8, Eid)) -> Result { + let operation: u8 = value & 0b0000_0011; + match operation { + 0x00 => Ok(SetEndpointIDOperation::SetEid(eid)), + 0x01 => Ok(SetEndpointIDOperation::ForceEid(eid)), + 0x02 => Ok(SetEndpointIDOperation::ResetEid), + 0x03 => Ok(SetEndpointIDOperation::SetDiscoveredFlag), + _ => Err(CompletionCode::ErrorInvalidData), + } + } +} + +impl From for (u8, Eid) { + fn from(operation: SetEndpointIDOperation) -> Self { + //TODO: Ok to use unwrap here? + let dummy: Eid = Eid::new_normal(8).unwrap(); + {}; + match operation { + SetEndpointIDOperation::SetEid(eid) => (0x00, eid), + SetEndpointIDOperation::ForceEid(eid) => (0x01, eid), + SetEndpointIDOperation::ResetEid => (0x02, dummy), + SetEndpointIDOperation::SetDiscoveredFlag => (0x03, dummy), + } + } +} + +impl From for (SetEndpointIdRequest) { + fn from(operation: SetEndpointIDOperation) -> SetEndpointIdRequest { + SetEndpointIdRequest(operation) + } +} + +impl From for (SetEndpointIDOperation) { + fn from(endpint_id_request: SetEndpointIdRequest) -> SetEndpointIDOperation { + endpint_id_request.0 + } +} + +#[derive(PartialEq, Eq, Clone, Hash, Debug)] +pub struct SetEndpointIdRequest(pub SetEndpointIDOperation); + +impl SetEndpointIdRequest { + pub fn new(operation: SetEndpointIDOperation) -> Self { + Self(operation) + } +} + +impl MctpCodec<'_> for SetEndpointIdRequest { + fn encode(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < 2 { + return Err(MctpCodecError::BufferTooShort); + } + + let (op, eid) = self.0.clone().into(); + *buffer.first_mut().ok_or(MctpCodecError::InternalError)? = op; + *buffer.get_mut(1).ok_or(MctpCodecError::InternalError)? = eid.0; + Ok(SetEndpointIdRequest::MCTP_CODEC_MIN_SIZE) + } + + fn decode(buffer: &[u8]) -> Result { + if buffer.len() < SetEndpointIdRequest::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + let op = *buffer.first().ok_or(MctpCodecError::InvalidData)?; + let eid = *buffer.get(1).ok_or(MctpCodecError::InvalidData)?; + let eid = Eid::new_normal(eid).map_err(|_| MctpCodecError::InvalidData)?; + let operation = + SetEndpointIDOperation::try_from((op, eid)).map_err(|_| MctpCodecError::InvalidData)?; + Ok(SetEndpointIdRequest(operation)) + } + + const MCTP_CODEC_MIN_SIZE: usize = 2; +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub enum EidAssignmentStatus { + Accepted, + Rejected, +} + +impl TryFrom for EidAssignmentStatus { + type Error = CompletionCode; + + fn try_from(value: u8) -> Result { + let status: u8 = (value >> 4) & 0b0000_0011; + match status { + 0x00 => Ok(EidAssignmentStatus::Accepted), + 0x01 => Ok(EidAssignmentStatus::Rejected), + _ => Err(CompletionCode::ErrorInvalidData), + } + } +} + +impl From for u8 { + fn from(status: EidAssignmentStatus) -> Self { + match status { + EidAssignmentStatus::Accepted => 0x00 << 4, + EidAssignmentStatus::Rejected => 0x01 << 4, + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub enum EidAllocationStatus { + NoEidPoolUsed, + EidPoolAllcotationRequired, + EidPoolAllcotationEstablished, +} + +impl TryFrom for EidAllocationStatus { + type Error = CompletionCode; + + fn try_from(value: u8) -> Result { + let status: u8 = value & 0b0000_0011; + match status { + 0x00 => Ok(EidAllocationStatus::NoEidPoolUsed), + 0x01 => Ok(EidAllocationStatus::EidPoolAllcotationRequired), + 0x02 => Ok(EidAllocationStatus::EidPoolAllcotationEstablished), + _ => Err(CompletionCode::ErrorInvalidData), + } + } +} + +impl From for u8 { + fn from(status: EidAllocationStatus) -> Self { + match status { + EidAllocationStatus::NoEidPoolUsed => 0x00, + EidAllocationStatus::EidPoolAllcotationRequired => 0x01, + EidAllocationStatus::EidPoolAllcotationEstablished => 0x02, + } + } +} + +#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] +pub struct SetEndpointIdResponse { + pub completion_code: CompletionCode, + + pub eid_assignment_status: EidAssignmentStatus, + + pub eid_allocation_status: EidAllocationStatus, + + pub eid_setting: Eid, + + pub eid_pool_size: u8, +} + +impl SetEndpointIdResponse { + pub fn new( + completion_code: CompletionCode, + eid_assignment_status: EidAssignmentStatus, + eid_allocation_status: EidAllocationStatus, + eid_setting: Eid, + eid_pool_size: u8, + ) -> Self { + Self { + completion_code, + eid_assignment_status, + eid_allocation_status, + eid_setting, + eid_pool_size, + } + } + + /// Creates a new `SetEndpointIdResponse` with an error completion code. + /// + /// # Panics + /// + /// Panics if the completion code is not one of the valid error codes: + /// `Error`, `ErrorInvalidData`, `ErrorInvalidLength`, `ErrorNotReady`, or `ErrorUnsupportedCmd`. + pub const fn new_err(completion_code: CompletionCode) -> Self { + assert!( + matches!( + completion_code, + CompletionCode::Error + | CompletionCode::ErrorInvalidData + | CompletionCode::ErrorInvalidLength + | CompletionCode::ErrorNotReady + | CompletionCode::ErrorUnsupportedCmd + ), + "Completion code must be an error code" + ); + Self { + completion_code, + eid_assignment_status: EidAssignmentStatus::Rejected, + eid_allocation_status: EidAllocationStatus::NoEidPoolUsed, + eid_setting: Eid(0), + eid_pool_size: 0, + } + } +} + +impl MctpCodec<'_> for SetEndpointIdResponse { + fn encode(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < SetEndpointIdResponse::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.completion_code.into(); + *buffer.get_mut(1).ok_or(MctpCodecError::InternalError)? = + u8::from(self.eid_assignment_status) | u8::from(self.eid_allocation_status); + *buffer.get_mut(2).ok_or(MctpCodecError::InternalError)? = self.eid_setting.0; + *buffer.get_mut(3).ok_or(MctpCodecError::InternalError)? = self.eid_pool_size; + Ok(SetEndpointIdResponse::MCTP_CODEC_MIN_SIZE) + } + + fn decode(buffer: &[u8]) -> Result { + if buffer.len() < SetEndpointIdResponse::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + let completion_code = + CompletionCode::from(*buffer.first().ok_or(MctpCodecError::InvalidData)?); + let eid_assignment_status = + EidAssignmentStatus::try_from(*buffer.get(1).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let eid_allocation_status = + EidAllocationStatus::try_from(*buffer.get(1).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let eid_setting = Eid::new_normal(*buffer.get(2).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let eid_pool_size = *buffer.get(3).ok_or(MctpCodecError::InvalidData)?; + + Ok(SetEndpointIdResponse { + completion_code, + eid_assignment_status, + eid_allocation_status, + eid_setting, + eid_pool_size, + }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 4; +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub enum EndpointType { + SimpleEndpoint, + BusOwnerOrBridge, +} + +impl From for u8 { + fn from(endpoint_type: EndpointType) -> Self { + match endpoint_type { + EndpointType::SimpleEndpoint => 0x00, + EndpointType::BusOwnerOrBridge => 0x10, + } + } +} + +impl TryFrom for EndpointType { + type Error = CompletionCode; + + fn try_from(value: u8) -> Result { + match (value >> 4) & 0b0000_0011 { + 0x00 => Ok(EndpointType::SimpleEndpoint), + 0x01 => Ok(EndpointType::BusOwnerOrBridge), + _ => Err(CompletionCode::ErrorInvalidData), + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub enum EidType { + DynamicEid, + StaticEid, + StaticEidConfigured, + StaticEidAvailable, +} + +impl From for u8 { + fn from(eid_type: EidType) -> Self { + match eid_type { + EidType::DynamicEid => 0x00, + EidType::StaticEid => 0x01, + EidType::StaticEidConfigured => 0x02, + EidType::StaticEidAvailable => 0x03, + } + } +} + +impl TryFrom for EidType { + type Error = CompletionCode; + + fn try_from(value: u8) -> Result { + match value & 0b0000_0011 { + 0x00 => Ok(EidType::DynamicEid), + 0x01 => Ok(EidType::StaticEid), + 0x02 => Ok(EidType::StaticEidConfigured), + 0x03 => Ok(EidType::StaticEidAvailable), + _ => Err(CompletionCode::ErrorInvalidData), + } + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct GetEndpointIDResponse { + completion_code: CompletionCode, + eid: Eid, + endpoint_type: EndpointType, + eid_type: EidType, + transport_specific_information: u8, +} + +impl GetEndpointIDResponse { + /// Create a new GetEndpointIDResponse + pub fn new( + completion_code: CompletionCode, + eid: Eid, + endpoint_type: EndpointType, + eid_type: EidType, + transport_specific_informatiuon: u8, + ) -> Self { + Self { + completion_code, + eid, + endpoint_type, + eid_type, + transport_specific_information: transport_specific_informatiuon, + } + } + + /// Creates a new `GetEndpointIDResponse` with an error completion code. + /// + /// # Panics + /// + /// Panics if the completion code is not one of the valid error codes: + /// `Error`, `ErrorInvalidData`, `ErrorInvalidLength`, `ErrorNotReady`, or `ErrorUnsupportedCmd`. + pub const fn new_err(completion_code: CompletionCode) -> Self { + assert!( + matches!( + completion_code, + CompletionCode::Error + | CompletionCode::ErrorInvalidData + | CompletionCode::ErrorInvalidLength + | CompletionCode::ErrorNotReady + | CompletionCode::ErrorUnsupportedCmd + ), + "Completion code must be an error code" + ); + Self { + completion_code, + eid: Eid(0), + endpoint_type: EndpointType::SimpleEndpoint, + eid_type: EidType::DynamicEid, + transport_specific_information: 0, + } + } +} + +impl MctpCodec<'_> for GetEndpointIDResponse { + fn encode(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < GetEndpointIDResponse::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.completion_code.into(); + *buffer.get_mut(1).ok_or(MctpCodecError::InternalError)? = self.eid.0; + *buffer.get_mut(2).ok_or(MctpCodecError::InternalError)? = + u8::from(self.endpoint_type) | u8::from(self.eid_type); + *buffer.get_mut(3).ok_or(MctpCodecError::InternalError)? = + self.transport_specific_information; + + Ok(GetEndpointIDResponse::MCTP_CODEC_MIN_SIZE) + } + + fn decode(buffer: &[u8]) -> Result { + if buffer.len() < GetEndpointIDResponse::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + let completion_code = + CompletionCode::from(*buffer.first().ok_or(MctpCodecError::InvalidData)?); + let eid = Eid::new_normal(*buffer.get(1).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let endpoint_type = + EndpointType::try_from(*buffer.get(2).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let eid_type = EidType::try_from(*buffer.get(2).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let transport_specific_informatiuon = *buffer.get(3).ok_or(MctpCodecError::InvalidData)?; + + Ok(GetEndpointIDResponse { + completion_code, + eid, + endpoint_type, + eid_type, + transport_specific_information: transport_specific_informatiuon, + }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 4; +} + +/// Represents the MCTP Message Type, which identifies the format and semantics of the message payload. +/// Each message type is associated with a specific code and protocol specification: +/// +/// Referenced from DSP0239 1.6.0 +/// +/// - `Control (0x00)`: Messages used to support initialization and configuration of MCTP communication within an MCTP network, as specified in DSP0236. +/// - `PLDM (0x01)`: Messages used to convey Platform Level Data Model (PLDM) traffic over MCTP, as specified in DSP0241. +/// - `NcSi (0x02)`: Messages used to convey NC-SI Control traffic over MCTP, as specified in DSP0261. +/// - `Ethernet (0x03)`: Messages used to convey Ethernet traffic over MCTP. See DSP0261. This message type can also be used separately by other specifications. +/// - `NvmeManagement (0x04)`: Messages used to convey NVM Express (NVMe) Management Messages over MCTP, as specified in DSP0235. +/// - `SPDM (0x05)`: Messages used to convey Security Protocol and Data Model Specification (SPDM) traffic over MCTP, as specified in DSP0275. +/// - `PciVdm (0x7E)`: Vendor Defined Message type used to support VDMs where the vendor is identified using a PCI-based vendor ID. The specification of the initial Message Header bytes for this message type is provided within DSP0236. The message body content is specified by the vendor, company, or organization identified by the given vendor ID. +/// - `IanaVdm (0x7F)`: Vendor Defined Message type used to support VDMs where the vendor is identified using an IANA-based vendor ID. This format uses an "Enterprise Number" assigned and maintained by the Internet Assigned Numbers Authority (IANA) as the means of identifying a particular vendor, company, or organization. The specification of the format of this message is given in DSP0236. The message body content is specified by the vendor, company, or organization identified by the given vendor ID. +/// - `Other(u8)`: Reserved for all other codes not explicitly defined above. +/// +/// See the relevant DSP specifications for details on message format and usage. +#[non_exhaustive] +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub enum MctpMessageType { + Control, + Pldm, + NcSi, + Ethernet, + NvmeManagement, + Spdm, + PciVdm, + IanaVdm, + Mctp, + Other(u8), +} + +impl From for MctpMessageType { + fn from(value: u8) -> Self { + match value { + 0x00 => MctpMessageType::Control, + 0x01 => MctpMessageType::Pldm, + 0x02 => MctpMessageType::NcSi, + 0x03 => MctpMessageType::Ethernet, + 0x04 => MctpMessageType::NvmeManagement, + 0x05 => MctpMessageType::Spdm, + 0x7E => MctpMessageType::PciVdm, + 0x7F => MctpMessageType::IanaVdm, + 0xFF => MctpMessageType::Mctp, + other => MctpMessageType::Other(other), + } + } +} + +impl From for u8 { + fn from(msg_type: MctpMessageType) -> Self { + match msg_type { + MctpMessageType::Control => 0x00, + MctpMessageType::Pldm => 0x01, + MctpMessageType::NcSi => 0x02, + MctpMessageType::Ethernet => 0x03, + MctpMessageType::NvmeManagement => 0x04, + MctpMessageType::Spdm => 0x05, + MctpMessageType::PciVdm => 0x7E, + MctpMessageType::IanaVdm => 0x7F, + MctpMessageType::Mctp => 0xFF, + MctpMessageType::Other(value) => value, + } + } +} + +impl From for mctp::MsgType { + fn from(msg_type: MctpMessageType) -> Self { + mctp::MsgType(msg_type.into()) + } +} + +impl From for MctpMessageType { + fn from(msg_type: mctp::MsgType) -> Self { + MctpMessageType::from(msg_type.0) + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +struct GetMCTPVersionSupportRequest { + mctp_message_type: MctpMessageType, +} + +impl GetMCTPVersionSupportRequest { + pub fn new(mctp_message_type: MctpMessageType) -> Self { + Self { mctp_message_type } + } +} + +impl MctpCodec<'_> for GetMCTPVersionSupportRequest { + fn encode(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < GetMCTPVersionSupportRequest::MCTP_CODEC_MIN_SIZE { + return Err(MctpCodecError::BufferTooShort); + } + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.mctp_message_type.into(); + Ok(GetMCTPVersionSupportRequest::MCTP_CODEC_MIN_SIZE) + } + + fn decode(buffer: &[u8]) -> Result { + let mctp_message_type = + MctpMessageType::from(*buffer.first().ok_or(MctpCodecError::InvalidData)?); + Ok(Self { mctp_message_type }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 1; +} +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct GetMCTPVersionSupportResponse<'a> { + pub completion_code: CompletionCode, + pub version_count: u8, + pub version_codes: &'a [u8], +} + +impl<'a> GetMCTPVersionSupportResponse<'a> { + pub fn new(completion_code: CompletionCode, version_codes: &'a [u8]) -> Self { + Self { + completion_code, + version_count: (version_codes.len() / 4) as u8, + version_codes, + } + } + + /// Creates a new GetMCTPVersionSupportResponse with an error completion code and empty version codes + /// + /// # Panics + /// + /// Panics if the completion code is not one of the valid error codes: + /// `Error`, `ErrorInvalidData`, `ErrorInvalidLength`, `ErrorNotReady`, or `ErrorUnsupportedCmd`. + pub const fn new_err(completion_code: CompletionCode) -> Self { + assert!( + matches!( + completion_code, + CompletionCode::Error + | CompletionCode::ErrorInvalidData + | CompletionCode::ErrorInvalidLength + | CompletionCode::ErrorNotReady + | CompletionCode::ErrorUnsupportedCmd + ), + "Completion code must be an error code" + ); + Self { + completion_code, + version_count: 0, + version_codes: &[], + } + } +} + +impl<'a> GetMCTPVersionSupportResponse<'a> { + /// Gets an iterator over version codes as u32 values. + /// + /// Returns None for malformed chunks that are not exactly 4 bytes. + pub fn get_version_code_iter(&self) -> impl Iterator> + '_ { + self.version_codes.chunks_exact(4).map(|chunk| { + chunk + .try_into() + .ok() + .map(|bytes: [u8; 4]| u32::from_be_bytes(bytes)) + }) + } + + /// Gets the version code at the specified position. + /// + /// Returns None if the position is invalid or if the version codes data is malformed. + pub fn get_version_code_at(&self, pos: usize) -> Option { + self.version_codes + .chunks_exact(4) + .nth(pos) + .and_then(|chunk| { + chunk + .try_into() + .ok() + .map(|bytes: [u8; 4]| u32::from_be_bytes(bytes)) + }) + } +} + +impl<'a> MctpCodec<'a> for GetMCTPVersionSupportResponse<'a> { + fn encode(&self, buffer: &mut [u8]) -> Result { + let max_version_pos = 2 + (self.version_count as usize) * 4; + + if buffer.len() < max_version_pos { + return Err(MctpCodecError::BufferTooShort); + } + + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.completion_code.into(); + *buffer.get_mut(1).ok_or(MctpCodecError::InternalError)? = self.version_count; + + let expected_length = 2 + self.version_count as usize * 4; + + buffer + .get_mut(2..expected_length) + .ok_or(MctpCodecError::InternalError)? + .copy_from_slice(self.version_codes); + + Ok(max_version_pos) + } + + fn decode(buffer: &'a [u8]) -> Result { + if buffer.len() < 2 { + return Err(MctpCodecError::BufferTooShort); + } + + let completion_code = + CompletionCode::from(*buffer.first().ok_or(MctpCodecError::InvalidData)?); + let version_count = *buffer.get(1).ok_or(MctpCodecError::InvalidData)?; + + if version_count > 8 { + return Err(MctpCodecError::UnsupportedBufferSize); + } + + let expected_len = 2 + (version_count as usize) * 4; + if buffer.len() < expected_len { + return Err(MctpCodecError::BufferTooShort); + } + + let version_numbers = buffer + .get(2..expected_len) + .ok_or(MctpCodecError::InvalidData)?; + + Ok(Self { + completion_code, + version_count, + version_codes: version_numbers, + }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 2 + 4 * 8; +} + +#[derive(Debug, PartialEq, Eq)] +pub struct GetMctpMessageTypeSupportResponse<'a> { + pub completion_code: CompletionCode, + pub message_type_count: u8, + pub message_types: &'a [u8], +} + +impl<'a> GetMctpMessageTypeSupportResponse<'a> { + pub fn new(completion_code: CompletionCode, message_types_buffer: &'a [u8]) -> Self { + Self { + completion_code, + message_type_count: message_types_buffer.len() as u8, + message_types: message_types_buffer, + } + } + + /// Creates a new `GetMctpMessageTypeSupportResponse` with an error completion code. + /// + /// # Panics + /// + /// Panics if the completion code is not one of the valid error codes: + /// `Error`, `ErrorInvalidData`, `ErrorInvalidLength`, `ErrorNotReady`, or `ErrorUnsupportedCmd`. + pub const fn new_err(completion_code: CompletionCode) -> Self { + assert!( + matches!( + completion_code, + CompletionCode::Error + | CompletionCode::ErrorInvalidData + | CompletionCode::ErrorInvalidLength + | CompletionCode::ErrorNotReady + | CompletionCode::ErrorUnsupportedCmd + ), + "Completion code must be an error code" + ); + Self { + completion_code, + message_type_count: 0, + message_types: &[], + } + } + + pub fn get_message_types_iterator(&self) -> impl Iterator + '_ { + self.message_types + .iter() + .map(|&byte| MctpMessageType::from(byte)) + } + + pub fn get_message_type(&self, index: usize) -> Option { + self.message_types + .get(index) + .map(|&byte| MctpMessageType::from(byte)) + } +} + +impl<'a> MctpCodec<'a> for GetMctpMessageTypeSupportResponse<'a> { + fn encode(&self, buffer: &mut [u8]) -> Result { + let required_len = 2 + (self.message_type_count as usize); + if buffer.len() < required_len { + return Err(MctpCodecError::BufferTooShort); + } + + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.completion_code.into(); + *buffer.get_mut(1).ok_or(MctpCodecError::InternalError)? = self.message_type_count; + + buffer + .get_mut(2..required_len) + .ok_or(MctpCodecError::InternalError)? + .copy_from_slice(self.message_types); + + Ok(required_len) + } + + fn decode(buffer: &'a [u8]) -> Result { + let completion_code: CompletionCode = + CompletionCode::from(*buffer.first().ok_or(MctpCodecError::InvalidData)?); + + let message_type_count = *buffer.get(1).ok_or(MctpCodecError::InvalidData)?; + let expected_len = 2 + (message_type_count as usize); + + if buffer.len() < expected_len { + return Err(MctpCodecError::InvalidData); + } + + let message_type_count = *buffer.get(1).ok_or(MctpCodecError::InvalidData)?; + + let message_types = buffer + .get(2..expected_len) + .ok_or(MctpCodecError::InvalidData)?; + + Ok(Self { + completion_code, + message_type_count, + message_types, + }) + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +struct ResolveEndpointIDRequest { + endpoint_id: Eid, +} + +impl ResolveEndpointIDRequest { + pub fn new(endpoint_id: Eid) -> Self { + Self { endpoint_id } + } +} + +impl MctpCodec<'_> for ResolveEndpointIDRequest { + fn encode(&self, buffer: &mut [u8]) -> Result { + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.endpoint_id.0; + Ok(1) + } + + fn decode(buffer: &[u8]) -> Result { + let eid = Eid::new_normal(*buffer.first().ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + Ok(Self { endpoint_id: eid }) + } +} + +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +struct ResolveEndpointIDResponse { + completion_code: CompletionCode, + bridge_endpoint_id: Eid, + physical_address: [u8; TRANSPORT_ADDRESS_LENGTH], +} + +impl ResolveEndpointIDResponse { + pub fn new( + completion_code: CompletionCode, + bridge_endpoint_id: Eid, + physical_address: [u8; TRANSPORT_ADDRESS_LENGTH], + ) -> Self { + Self { + completion_code, + bridge_endpoint_id, + physical_address, + } + } + + /// Creates a new `ResolveEndpointIDResponse` with an error completion code. + /// + /// # Panics + /// + /// Panics if the completion code is not one of the valid error codes: + /// `Error`, `ErrorInvalidData`, `ErrorInvalidLength`, `ErrorNotReady`, or `ErrorUnsupportedCmd`. + pub const fn new_err(completion_code: CompletionCode) -> Self { + assert!( + matches!( + completion_code, + CompletionCode::Error + | CompletionCode::ErrorInvalidData + | CompletionCode::ErrorInvalidLength + | CompletionCode::ErrorNotReady + | CompletionCode::ErrorUnsupportedCmd + ), + "Completion code must be an error code" + ); + Self { + completion_code, + bridge_endpoint_id: Eid(0), + physical_address: [0; TRANSPORT_ADDRESS_LENGTH], + } + } +} + +impl MctpCodec<'_> + for ResolveEndpointIDResponse +{ + fn encode(&self, buffer: &mut [u8]) -> Result { + let required_len = 2 + TRANSPORT_ADDRESS_LENGTH; + if buffer.len() < required_len { + return Err(MctpCodecError::BufferTooShort); + } + *buffer.get_mut(0).ok_or(MctpCodecError::InternalError)? = self.completion_code.into(); + *buffer.get_mut(1).ok_or(MctpCodecError::InternalError)? = self.bridge_endpoint_id.0; + buffer + .get_mut(2..required_len) + .ok_or(MctpCodecError::InternalError)? + .copy_from_slice(&self.physical_address); + Ok(required_len) + } + + fn decode(buffer: &[u8]) -> Result { + if buffer.len() < 2 { + return Err(MctpCodecError::BufferTooShort); + } + let completion_code = + CompletionCode::from(*buffer.first().ok_or(MctpCodecError::InvalidData)?); + let bridge_endpoint_id = + Eid::new_normal(*buffer.get(1).ok_or(MctpCodecError::InvalidData)?) + .map_err(|_| MctpCodecError::InvalidData)?; + let address_len = buffer.len() - 2; + if address_len != TRANSPORT_ADDRESS_LENGTH { + return Err(MctpCodecError::UnsupportedBufferSize); + } + let mut physical_address = [0u8; TRANSPORT_ADDRESS_LENGTH]; + physical_address.copy_from_slice( + buffer + .get(2..2 + TRANSPORT_ADDRESS_LENGTH) + .ok_or(MctpCodecError::InvalidData)?, + ); + Ok(Self { + completion_code, + bridge_endpoint_id, + physical_address, + }) + } + + const MCTP_CODEC_MIN_SIZE: usize = 2 + TRANSPORT_ADDRESS_LENGTH; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mctp_control_header_new() { + let header = MctpControlHeader::new(true, false, 0x1F, CommandCode::SetEndpointID); + + assert!(header.request); + assert!(!header.datagram); + assert_eq!(header.instance_id, 0x1F); + assert_eq!(header.command_code, CommandCode::SetEndpointID); + } + + #[test] + fn test_mctp_control_header_encode_basic() { + let header = MctpControlHeader::new(true, false, 0x15, CommandCode::GetEndpointID); + + let mut buffer = [0u8; 4]; + let result = header.encode(&mut buffer); + + let mctp_header_expected_size: usize = 2; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), mctp_header_expected_size); + + // Check encoded values + // request=1, datagram=0, instance_id=0x15 -> 0b1001_0101 = 0x95 + assert_eq!(buffer.first().copied().unwrap(), 0x95); + assert_eq!(buffer.get(1).copied().unwrap(), 0x02); // GetEndpointID = 0x02 + } + + #[test] + fn test_mctp_control_header_encode_buffer_too_short() { + let header = MctpControlHeader::new(true, false, 0x10, CommandCode::GetEndpointID); + + let mut buffer = [0u8; 1]; // Too short + let result = header.encode(&mut buffer); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), MctpCodecError::BufferTooShort); + } + + #[test] + fn test_mctp_control_header_decode_basic() { + let buffer = [0x95, 0x02]; // request=1, datagram=0, instance_id=0x15, cmd=0x02 + + let result = MctpControlHeader::decode(&buffer); + + assert!(result.is_ok()); + let header = result.unwrap(); + assert!(header.request); + assert!(!header.datagram); + assert_eq!(header.instance_id, 0x15); + assert_eq!(header.command_code, CommandCode::GetEndpointID); + } + + #[test] + fn test_mctp_control_header_decode_all_flags() { + let buffer = [0xFF, 0x03]; // request=1, datagram=1, instance_id=0x3F, cmd=0x03 + + let result = MctpControlHeader::decode(&buffer); + + assert!(result.is_ok()); + let header = result.unwrap(); + assert!(header.request); + assert!(header.datagram); + assert_eq!(header.instance_id, 0x3F); + assert_eq!(header.command_code, CommandCode::GetEndpointUUID); + } + + #[test] + fn test_mctp_control_header_decode_no_flags() { + let buffer = [0x00, 0x01]; // All flags false, cmd=0x01 + + let result = MctpControlHeader::decode(&buffer); + + assert!(result.is_ok()); + let header = result.unwrap(); + assert!(!header.request); + assert!(!header.datagram); + assert_eq!(header.instance_id, 0x00); + assert_eq!(header.command_code, CommandCode::SetEndpointID); + } + + #[test] + fn test_mctp_control_header_decode_buffer_too_short() { + let buffer = [0x95]; // Only 1 byte + + let result = MctpControlHeader::decode(&buffer); + + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), MctpCodecError::InvalidData); + } + + #[test] + fn test_mctp_control_header_decode_unknown_command() { + let buffer = [0x80, 0xFF]; // request=1, datagram=0, instance_id=0, unknown cmd + + let result = MctpControlHeader::decode(&buffer); + + assert!(result.is_ok()); + let header = result.unwrap(); + assert!(header.request); + assert!(!header.datagram); + assert_eq!(header.instance_id, 0x00); + assert_eq!(header.command_code, CommandCode::TransportSpecific(0xFF)); + } + + #[test] + fn test_mctp_control_header_round_trip() { + let original = MctpControlHeader::new(true, true, 0x2A, CommandCode::GetMCTPVersionSupport); + + let mut buffer = [0u8; 4]; + let encode_result = original.encode(&mut buffer); + assert!(encode_result.is_ok()); + + let decode_result = MctpControlHeader::decode(&buffer); + assert!(decode_result.is_ok()); + + let decoded = decode_result.unwrap(); + assert_eq!(original, decoded); + } + + #[test] + fn test_mctp_control_header_instance_id_boundary() { + // Test maximum valid instance_id (6 bits = 0x3F) + let header = MctpControlHeader::new(false, false, 0x3F, CommandCode::GetEndpointID); + + let mut buffer = [0u8; 4]; + let encode_result = header.encode(&mut buffer); + assert!(encode_result.is_ok()); + + let decode_result = MctpControlHeader::decode(&buffer); + assert!(decode_result.is_ok()); + + let decoded = decode_result.unwrap(); + assert_eq!(decoded.instance_id, 0x3F); + } + + #[test] + fn test_mctp_control_header_various_commands() { + let test_cases = vec![ + CommandCode::SetEndpointID, + CommandCode::GetEndpointID, + CommandCode::GetEndpointUUID, + CommandCode::GetMCTPVersionSupport, + CommandCode::GetMessageTypeSupport, + CommandCode::ResolveEndpointID, + CommandCode::QueryHop, + CommandCode::TransportSpecific(0xF5), + CommandCode::Unknown(0x42), + ]; + + for cmd in test_cases { + let header = MctpControlHeader::new(true, false, 0x10, cmd); + + let mut buffer = [0u8; 4]; + let encode_result = header.encode(&mut buffer); + assert!(encode_result.is_ok()); + + let decode_result = MctpControlHeader::decode(&buffer); + assert!(decode_result.is_ok()); + + let decoded = decode_result.unwrap(); + assert_eq!(decoded.command_code, cmd); + } + } + + #[test] + fn test_mctp_control_header_min_size_constant() { + assert_eq!(MctpControlHeader::MCTP_CODEC_MIN_SIZE, 2); + } + + #[test] + fn test_mctp_message_type_from_u8() { + // Test all defined variants + assert_eq!(MctpMessageType::from(0x00), MctpMessageType::Control); + assert_eq!(MctpMessageType::from(0x01), MctpMessageType::Pldm); + assert_eq!(MctpMessageType::from(0x02), MctpMessageType::NcSi); + assert_eq!(MctpMessageType::from(0x03), MctpMessageType::Ethernet); + assert_eq!(MctpMessageType::from(0x04), MctpMessageType::NvmeManagement); + assert_eq!(MctpMessageType::from(0x05), MctpMessageType::Spdm); + assert_eq!(MctpMessageType::from(0x7E), MctpMessageType::PciVdm); + assert_eq!(MctpMessageType::from(0x7F), MctpMessageType::IanaVdm); + assert_eq!(MctpMessageType::from(0xFF), MctpMessageType::Mctp); + + // Test Other variant for undefined values + assert_eq!(MctpMessageType::from(0x42), MctpMessageType::Other(0x42)); + assert_eq!(MctpMessageType::from(0x80), MctpMessageType::Other(0x80)); + assert_eq!(MctpMessageType::from(0x10), MctpMessageType::Other(0x10)); + } + + #[test] + fn test_mctp_message_type_to_u8() { + // Test all defined variants + assert_eq!(u8::from(MctpMessageType::Control), 0x00); + assert_eq!(u8::from(MctpMessageType::Pldm), 0x01); + assert_eq!(u8::from(MctpMessageType::NcSi), 0x02); + assert_eq!(u8::from(MctpMessageType::Ethernet), 0x03); + assert_eq!(u8::from(MctpMessageType::NvmeManagement), 0x04); + assert_eq!(u8::from(MctpMessageType::Spdm), 0x05); + assert_eq!(u8::from(MctpMessageType::PciVdm), 0x7E); + assert_eq!(u8::from(MctpMessageType::IanaVdm), 0x7F); + assert_eq!(u8::from(MctpMessageType::Mctp), 0xFF); + + // Test Other variant + assert_eq!(u8::from(MctpMessageType::Other(0x42)), 0x42); + assert_eq!(u8::from(MctpMessageType::Other(0x99)), 0x99); + } + + #[test] + fn test_mctp_message_type_round_trip_u8() { + // Test round-trip conversions for all defined values + let test_values = vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x7E, 0x7F, 0xFF]; + + for value in test_values { + let msg_type = MctpMessageType::from(value); + let back_to_u8 = u8::from(msg_type); + assert_eq!(value, back_to_u8); + } + + // Test round-trip for Other variants + let other_values = vec![0x42, 0x80, 0x10, 0x20, 0x60, 0x90]; + for value in other_values { + let msg_type = MctpMessageType::from(value); + let back_to_u8 = u8::from(msg_type); + assert_eq!(value, back_to_u8); + } + } + + #[test] + fn test_mctp_message_type_to_mctp_msgtype() { + // Test conversion to mctp::MsgType + let control_type = MctpMessageType::Control; + let mctp_msg_type: mctp::MsgType = control_type.into(); + assert_eq!(mctp_msg_type.0, 0x00); + + let pldm_type = MctpMessageType::Pldm; + let mctp_msg_type: mctp::MsgType = pldm_type.into(); + assert_eq!(mctp_msg_type.0, 0x01); + + let other_type = MctpMessageType::Other(0x42); + let mctp_msg_type: mctp::MsgType = other_type.into(); + assert_eq!(mctp_msg_type.0, 0x42); + } + + #[test] + fn test_mctp_msgtype_to_message_type() { + // Test conversion from mctp::MsgType + let mctp_msg_type = mctp::MsgType(0x00); + let msg_type = MctpMessageType::from(mctp_msg_type); + assert_eq!(msg_type, MctpMessageType::Control); + + let mctp_msg_type = mctp::MsgType(0x7F); + let msg_type = MctpMessageType::from(mctp_msg_type); + assert_eq!(msg_type, MctpMessageType::IanaVdm); + + let mctp_msg_type = mctp::MsgType(0x42); + let msg_type = MctpMessageType::from(mctp_msg_type); + assert_eq!(msg_type, MctpMessageType::Other(0x42)); + } + + #[test] + fn test_mctp_message_type_round_trip_mctp_msgtype() { + // Test round-trip conversion through mctp::MsgType + let test_types = vec![ + MctpMessageType::Control, + MctpMessageType::Pldm, + MctpMessageType::NcSi, + MctpMessageType::Ethernet, + MctpMessageType::NvmeManagement, + MctpMessageType::Spdm, + MctpMessageType::PciVdm, + MctpMessageType::IanaVdm, + MctpMessageType::Mctp, + MctpMessageType::Other(0x42), + MctpMessageType::Other(0x99), + ]; + + for original_type in test_types { + let mctp_msg_type: mctp::MsgType = original_type.into(); + let back_to_msg_type = MctpMessageType::from(mctp_msg_type); + assert_eq!(original_type, back_to_msg_type); + } + } + + #[test] + fn test_mctp_message_type_edge_cases() { + // Test boundary values + assert_eq!(MctpMessageType::from(0x00), MctpMessageType::Control); + assert_eq!(MctpMessageType::from(0xFF), MctpMessageType::Mctp); + + // Test values just before and after defined ranges + assert_eq!(MctpMessageType::from(0x06), MctpMessageType::Other(0x06)); + assert_eq!(MctpMessageType::from(0x7D), MctpMessageType::Other(0x7D)); + assert_eq!(MctpMessageType::from(0xFE), MctpMessageType::Other(0xFE)); + } + + #[test] + fn test_set_endpoint_id_response_constructor() { + let eid = Eid::new_normal(20).unwrap(); + let response = SetEndpointIdResponse::new( + CompletionCode::Success, + EidAssignmentStatus::Accepted, + EidAllocationStatus::NoEidPoolUsed, + eid, + 5, + ); + assert_eq!(response.completion_code, CompletionCode::Success); + assert_eq!(response.eid_assignment_status, EidAssignmentStatus::Accepted); + assert_eq!(response.eid_allocation_status, EidAllocationStatus::NoEidPoolUsed); + assert_eq!(response.eid_setting, eid); + assert_eq!(response.eid_pool_size, 5); + } + + #[test] + fn test_get_endpoint_id_response_constructor() { + let eid = Eid::new_normal(30).unwrap(); + let response = GetEndpointIDResponse::new( + CompletionCode::Success, + eid, + EndpointType::BusOwnerOrBridge, + EidType::StaticEid, + 42, + ); + assert_eq!(response.completion_code, CompletionCode::Success); + assert_eq!(response.eid, eid); + assert_eq!(response.endpoint_type, EndpointType::BusOwnerOrBridge); + assert_eq!(response.eid_type, EidType::StaticEid); + assert_eq!(response.transport_specific_information, 42); + } + + #[test] + fn test_get_mctp_version_support_request_constructor() { + let request = GetMCTPVersionSupportRequest::new(MctpMessageType::Pldm); + assert_eq!(request.mctp_message_type, MctpMessageType::Pldm); + } + + #[test] + fn test_get_mctp_version_support_response_constructor() { + let version_codes = &[1, 0, 0, 0, 2, 0, 0, 0]; + let response = GetMCTPVersionSupportResponse::new(CompletionCode::Success, version_codes); + assert_eq!(response.completion_code, CompletionCode::Success); + assert_eq!(response.version_count, 2); + assert_eq!(response.version_codes, version_codes); + } + + #[test] + fn test_get_mctp_message_type_support_response_constructor() { + let message_types = &[0x00, 0x01, 0x02]; + let response = GetMctpMessageTypeSupportResponse::new(CompletionCode::Success, message_types); + assert_eq!(response.completion_code, CompletionCode::Success); + assert_eq!(response.message_type_count, 3); + assert_eq!(response.message_types, message_types); + } + + #[test] + fn test_resolve_endpoint_id_request_constructor() { + let eid = Eid::new_normal(40).unwrap(); + let request = ResolveEndpointIDRequest::new(eid); + assert_eq!(request.endpoint_id, eid); + } + + #[test] + fn test_resolve_endpoint_id_response_constructor() { + let bridge_eid = Eid::new_normal(50).unwrap(); + let physical_address = [1, 2, 3, 4]; + let response = ResolveEndpointIDResponse::<4>::new( + CompletionCode::Success, + bridge_eid, + physical_address, + ); + assert_eq!(response.completion_code, CompletionCode::Success); + assert_eq!(response.bridge_endpoint_id, bridge_eid); + assert_eq!(response.physical_address, physical_address); + } + + // Tests for error constructors + #[test] + fn test_set_endpoint_id_response_new_err() { + let response = SetEndpointIdResponse::new_err(CompletionCode::Error); + assert_eq!(response.completion_code, CompletionCode::Error); + assert_eq!(response.eid_assignment_status, EidAssignmentStatus::Rejected); + assert_eq!(response.eid_allocation_status, EidAllocationStatus::NoEidPoolUsed); + assert_eq!(response.eid_setting, Eid(0)); + assert_eq!(response.eid_pool_size, 0); + } + + #[test] + fn test_get_endpoint_id_response_new_err() { + let response = GetEndpointIDResponse::new_err(CompletionCode::ErrorInvalidData); + assert_eq!(response.completion_code, CompletionCode::ErrorInvalidData); + assert_eq!(response.eid, Eid(0)); + assert_eq!(response.endpoint_type, EndpointType::SimpleEndpoint); + assert_eq!(response.eid_type, EidType::DynamicEid); + assert_eq!(response.transport_specific_information, 0); + } + + #[test] + fn test_get_mctp_version_support_response_new_err() { + let response = GetMCTPVersionSupportResponse::new_err(CompletionCode::ErrorUnsupportedCmd); + assert_eq!(response.completion_code, CompletionCode::ErrorUnsupportedCmd); + assert_eq!(response.version_count, 0); + assert!(response.version_codes.is_empty()); + } + + #[test] + fn test_get_mctp_message_type_support_response_new_err() { + let response = GetMctpMessageTypeSupportResponse::new_err(CompletionCode::ErrorNotReady); + assert_eq!(response.completion_code, CompletionCode::ErrorNotReady); + assert_eq!(response.message_type_count, 0); + assert!(response.message_types.is_empty()); + } + + #[test] + fn test_resolve_endpoint_id_response_new_err() { + let response = ResolveEndpointIDResponse::<6>::new_err(CompletionCode::ErrorInvalidLength); + assert_eq!(response.completion_code, CompletionCode::ErrorInvalidLength); + assert_eq!(response.bridge_endpoint_id, Eid(0)); + assert_eq!(response.physical_address, [0; 6]); + } + + // Encode/decode round-trip tests + #[test] + fn test_set_endpoint_id_request_round_trip() { + let eid = Eid::new_normal(15).unwrap(); + let operation = SetEndpointIDOperation::ForceEid(eid); + let original = SetEndpointIdRequest(operation); + + let mut buffer = [0u8; 10]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = SetEndpointIdRequest::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_set_endpoint_id_response_round_trip() { + let eid = Eid::new_normal(25).unwrap(); + let original = SetEndpointIdResponse::new( + CompletionCode::Success, + EidAssignmentStatus::Accepted, + EidAllocationStatus::EidPoolAllcotationEstablished, + eid, + 8, + ); + + let mut buffer = [0u8; 10]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = SetEndpointIdResponse::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_get_endpoint_id_response_round_trip() { + let eid = Eid::new_normal(35).unwrap(); + let original = GetEndpointIDResponse::new( + CompletionCode::Success, + eid, + EndpointType::BusOwnerOrBridge, + EidType::StaticEidConfigured, + 123, + ); + + let mut buffer = [0u8; 10]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = GetEndpointIDResponse::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_get_mctp_version_support_request_round_trip() { + let original = GetMCTPVersionSupportRequest::new(MctpMessageType::Spdm); + + let mut buffer = [0u8; 10]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = GetMCTPVersionSupportRequest::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_get_mctp_version_support_response_round_trip() { + let version_codes = &[1, 0, 0, 0, 3, 0, 0, 0, 5, 0, 0, 0]; + let original = GetMCTPVersionSupportResponse::new(CompletionCode::Success, version_codes); + + let mut buffer = [0u8; 50]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = GetMCTPVersionSupportResponse::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_get_mctp_message_type_support_response_round_trip() { + let message_types = &[0x00, 0x01, 0x05, 0x7E]; + let original = GetMctpMessageTypeSupportResponse::new(CompletionCode::Success, message_types); + + let mut buffer = [0u8; 20]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = GetMctpMessageTypeSupportResponse::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_resolve_endpoint_id_request_round_trip() { + let eid = Eid::new_normal(45).unwrap(); + let original = ResolveEndpointIDRequest::new(eid); + + let mut buffer = [0u8; 10]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = ResolveEndpointIDRequest::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_resolve_endpoint_id_response_round_trip() { + let bridge_eid = Eid::new_normal(55).unwrap(); + let physical_address = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]; + let original = ResolveEndpointIDResponse::new( + CompletionCode::Success, + bridge_eid, + physical_address, + ); + + let mut buffer = [0u8; 20]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = ResolveEndpointIDResponse::<6>::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + #[test] + fn test_mctp_control_message_round_trip() { + let header = MctpControlHeader::new(true, false, 0x20, CommandCode::GetEndpointUUID); + let message_body = &[0x12, 0x34, 0x56, 0x78]; + let original = MctpControlMessage::new(header, message_body); + + let mut buffer = [0u8; 20]; + let encoded_size = original.encode(&mut buffer).unwrap(); + let decoded = MctpControlMessage::decode( + buffer.get(..encoded_size).ok_or("Buffer slice error").unwrap() + ).unwrap(); + + assert_eq!(original, decoded); + } + + // Test error cases for encode/decode + #[test] + fn test_encode_decode_error_cases() { + // Test buffer too short for SetEndpointIdRequest + let eid = Eid::new_normal(10).unwrap(); + let operation = SetEndpointIDOperation::SetEid(eid); + let request = SetEndpointIdRequest(operation); + + let mut small_buffer = [0u8; 1]; + assert_eq!(request.encode(&mut small_buffer), Err(MctpCodecError::BufferTooShort)); + + // Test invalid data for decode + let invalid_buffer = [0xFF]; // Too short + assert_eq!(SetEndpointIdRequest::decode(invalid_buffer.as_slice()), Err(MctpCodecError::BufferTooShort)); + } +} diff --git a/standalone/Cargo.toml b/standalone/Cargo.toml new file mode 100644 index 0000000..4042179 --- /dev/null +++ b/standalone/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "standalone" +version = "0.1.0" +edition = "2024" +authors = ["OpenPRoT Contributors"] +license = "Apache-2.0" +repository = "https://github.com/OpenPRoT/mctp-lib" +description = "Standalone std implementation of mctp-lib" + +[dependencies] +mctp-lib = { path = "../" } +mctp = { git = "https://github.com/OpenPRoT/mctp-rs.git", branch = "sync-features" } +embedded-io-adapters = { version = "0.6.0", features = ["std"] } diff --git a/standalone/src/lib.rs b/standalone/src/lib.rs new file mode 100644 index 0000000..c5e917f --- /dev/null +++ b/standalone/src/lib.rs @@ -0,0 +1,344 @@ +// Copyright 2025 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Standalone implementation of mctp-lib usind [std] platform abstactions. +//! +//! Intended for use in examples and tests. + +pub mod serial_sender; + +use std::collections::HashMap; +use std::sync::{Arc, Condvar, Mutex}; +use std::time::{Duration, Instant}; + +use mctp::{Eid, Error, Listener, MsgIC, MsgType, ReqChannel, RespChannel, Tag}; +use mctp_lib::{AppCookie, Router, Sender}; + +const MAX_LISTENER_HANDLES: usize = 128; +const MAX_REQ_HANDLES: usize = 128; + +/// STD MCTP stack +/// +/// Encapsulates a inner [Router] in a thread safe and sharable manner. +/// Provides implementations for the [mctp] traits that hold references to the `stack`. +pub struct Stack { + inner: Arc>>, + /// Notifiers to inform _requests_ and _listeners_ about new messages. + notifiers: Arc>>>, + start_time: Instant, +} + +/// A request implementing [ReqChannel] +#[derive(Debug)] +pub struct Request { + /// Thread safe reference to a stack + stack: Arc>>, + cookie: AppCookie, + /// The [Condvar] that nofifies the request once the response is available + notifier: Arc, + timeout: Option, + tag: Option, +} +/// A listener implementing [Listener] +#[derive(Debug)] +pub struct ReqListener { + stack: Arc>>, + notifiers: Arc>>>, + cookie: AppCookie, + notifier: Arc, + timeout: Option, +} +/// A response for a request received by a [ReqListener] +#[derive(Debug)] +pub struct Response { + stack: Arc>>, + notifiers: Arc>>>, + tag: Tag, + typ: MsgType, + remote_eid: Eid, +} + +impl Stack { + pub fn new(outbound: S) -> Self { + let inner = Router::new(Eid(0), 0, outbound); + Self { + inner: Arc::new(Mutex::new(inner)), + notifiers: Arc::new(Mutex::new(HashMap::new())), + start_time: Instant::now(), + } + } + pub fn request(&mut self, dest: Eid, timeout: Option) -> mctp::Result> { + let handle = self + .inner + .lock() + .map_err(|_| Error::InternalError)? + .req(dest)?; + let mut notifiers = self.notifiers.lock().map_err(|_| Error::InternalError)?; + let notifier = Arc::new(Condvar::new()); + notifiers.insert(handle, Arc::clone(¬ifier)); + Ok(Request { + stack: self.inner.clone(), + cookie: handle, + notifier, + timeout, + tag: None, + }) + } + pub fn listener( + &mut self, + typ: MsgType, + timeout: Option, + ) -> mctp::Result> { + let handle = self + .inner + .lock() + .map_err(|_| Error::InternalError)? + .listener(typ)?; + let mut notifiers = self.notifiers.lock().map_err(|_| Error::InternalError)?; + let notifier = Arc::new(Condvar::new()); + notifiers.insert(handle, Arc::clone(¬ifier)); + Ok(ReqListener { + stack: self.inner.clone(), + cookie: handle, + notifier, + timeout, + notifiers: Arc::clone(&self.notifiers), + }) + } + + pub fn inbound(&mut self, pkt: &[u8]) -> Result<(), Error> { + let cookie = self + .inner + .lock() + .map_err(|_| Error::InternalError)? + .inbound(pkt)?; + if let Some(handle) = cookie { + let notifiers = self.notifiers.lock().map_err(|_| Error::InternalError)?; + let notifier = notifiers.get(&handle); + notifier.inspect(|c| c.notify_all()); + } + Ok(()) + } + + /// Call the update function of the inner stack with the current timestamp + /// + /// Convenience function that gets the current timestamp by calculating the duration since the stack was initialized (using [std::time]). + pub fn update(&mut self) -> Result { + self.inner + .lock() + .map_err(|_| Error::InternalError)? + .update(Instant::now().duration_since(self.start_time).as_millis() as u64) + } + + /// Set the stacks EID + pub fn set_eid(&mut self, eid: Eid) -> Result<(), Error> { + self.inner + .lock() + .map_err(|_| Error::InternalError)? + .set_eid(eid) + } +} + +impl Clone for Stack { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + notifiers: Arc::clone(&self.notifiers), + start_time: self.start_time, + } + } +} + +impl ReqChannel for Request { + fn send_vectored( + &mut self, + typ: mctp::MsgType, + integrity_check: mctp::MsgIC, + bufs: &[&[u8]], + ) -> mctp::Result<()> { + let tag = self + .stack + .lock() + .map_err(|_| Error::InternalError)? + .send_vectored(None, typ, None, integrity_check, self.cookie, bufs)?; + self.tag = Some(tag); + Ok(()) + } + + fn recv<'f>( + &mut self, + buf: &'f mut [u8], + ) -> mctp::Result<(mctp::MsgType, mctp::MsgIC, &'f mut [u8])> { + let Some(tag) = self.tag else { + return Err(Error::BadArgument); + }; + let mut stack = self.stack.lock().unwrap(); + loop { + if let Some(mut msg) = stack.recv(self.cookie) { + if msg.tag.tag() != tag.tag() { + msg.retain(); + return Err(Error::InternalError); + } + buf.get_mut(..msg.payload.len()) + .ok_or(Error::NoSpace)? + .copy_from_slice(msg.payload); + return Ok((msg.typ, msg.ic, &mut buf[..msg.payload.len()])); + } + if let Some(timeout) = self.timeout { + let (stack_result, timeout_result) = + self.notifier.wait_timeout(stack, timeout).unwrap(); + if timeout_result.timed_out() { + return Err(Error::TimedOut); + } else { + stack = stack_result; + } + } else { + stack = self.notifier.wait(stack).unwrap(); + } + } + } + + fn remote_eid(&self) -> Eid { + todo!() + } +} + +impl Listener for ReqListener { + type RespChannel<'a> + = Response + where + Self: 'a; + + fn recv<'f>( + &mut self, + buf: &'f mut [u8], + ) -> mctp::Result<(MsgType, MsgIC, &'f mut [u8], Self::RespChannel<'_>)> { + let mut stack = self.stack.lock().unwrap(); + loop { + if let Some(msg) = stack.recv(self.cookie) { + buf.get_mut(..msg.payload.len()) + .ok_or(Error::NoSpace)? + .copy_from_slice(msg.payload); + let resp = Response { + stack: Arc::clone(&self.stack), + tag: Tag::Unowned(msg.tag.tag()), + remote_eid: msg.source, + typ: msg.typ, + notifiers: Arc::clone(&self.notifiers), + }; + return Ok((msg.typ, msg.ic, &mut buf[..msg.payload.len()], resp)); + } + if let Some(timeout) = self.timeout { + let (stack_result, timeout_result) = + self.notifier.wait_timeout(stack, timeout).unwrap(); + if timeout_result.timed_out() { + return Err(Error::TimedOut); + } else { + stack = stack_result; + } + } else { + stack = self.notifier.wait(stack).unwrap(); + } + } + } +} + +impl RespChannel for Response { + type ReqChannel = Request; + + fn send_vectored(&mut self, integrity_check: MsgIC, bufs: &[&[u8]]) -> mctp::Result<()> { + self.stack + .lock() + .map_err(|_| Error::InternalError)? + .send_vectored( + Some(self.remote_eid), + self.typ, + Some(self.tag), + integrity_check, + AppCookie(255), // TODO improve this in mctp-lib + bufs, + )?; + Ok(()) + } + + fn remote_eid(&self) -> Eid { + self.remote_eid + } + + fn req_channel(&self) -> mctp::Result { + let handle = self + .stack + .lock() + .map_err(|_| Error::InternalError)? + .req(self.remote_eid)?; + let mut notifiers = self.notifiers.lock().map_err(|_| Error::InternalError)?; + let notifier = Arc::new(Condvar::new()); + notifiers.insert(handle, Arc::clone(¬ifier)); + Ok(Request { + stack: self.stack.clone(), + cookie: handle, + notifier, + timeout: None, + tag: None, + }) + } +} + +pub mod util { + use std::{ + io::{BufReader, Read}, + thread::sleep, + time::Duration, + }; + + use crate::Stack; + use embedded_io_adapters::std::FromStd; + use mctp_lib::Sender; + + /// Loop that updates the `stack` periodically + /// + /// The stack gets updated atleast once every 100 ms. + pub fn update_loop(mut stack: Stack) -> ! { + loop { + let timeout = match stack.update() { + Ok(t) => t, + Err(e) => { + println!("Error updating stack: {e}"); + 100 + } + }; + + sleep(Duration::from_millis(timeout)); + } + } + + /// Loop that reads packets from the `serial` line into the stack + pub fn inbound_loop(mut stack: Stack, serial: R) -> ! { + let mut reader = FromStd::new(BufReader::new(serial)); + let mut serial_transport = mctp_lib::serial::MctpSerialHandler::new(); + loop { + let Ok(pkt) = serial_transport + .recv_sync(&mut reader) + .inspect_err(|e| println!("Error receiving serial data: {e}")) + else { + continue; + }; + + stack + .inbound(pkt) + .inspect_err(|e| println!("Error processing inbound packet: {e}")) + .ok(); + } + } +} diff --git a/standalone/src/serial_sender.rs b/standalone/src/serial_sender.rs new file mode 100644 index 0000000..fa25bba --- /dev/null +++ b/standalone/src/serial_sender.rs @@ -0,0 +1,58 @@ +// Copyright 2025 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use embedded_io_adapters::std::FromStd; +use mctp::Error; +use std::io::Write; + +use mctp_lib::{Sender, fragment::SendOutput, serial::MctpSerialHandler}; + +pub struct IoSerialSender { + writer: FromStd, + serial_handler: MctpSerialHandler, +} +impl IoSerialSender { + pub fn new(writer: W) -> Self { + IoSerialSender { + writer: FromStd::new(writer), + serial_handler: MctpSerialHandler::new(), + } + } +} + +impl Sender for IoSerialSender { + fn send_vectored( + &mut self, + _eid: mctp::Eid, + mut fragmenter: mctp_lib::fragment::Fragmenter, + payload: &[&[u8]], + ) -> mctp::Result { + loop { + let mut pkt = [0; mctp_lib::serial::MTU_MAX]; + let fragment = fragmenter.fragment_vectored(payload, &mut pkt); + match fragment { + SendOutput::Packet(items) => { + self.serial_handler.send_sync(items, &mut self.writer)?; + self.writer.inner_mut().flush().map_err(Error::Io)?; + } + SendOutput::Complete { tag, cookie: _ } => return Ok(tag), + SendOutput::Error { err, cookie: _ } => return Err(err), + } + } + } + + fn get_mtu(&self) -> usize { + mctp_lib::serial::MTU_MAX + } +}