From 8bef238a0f18287aa49b2ef33afdd9de5bdec374 Mon Sep 17 00:00:00 2001 From: Robert Zieba Date: Tue, 9 Dec 2025 14:55:01 -0800 Subject: [PATCH 1/2] Remove possible panic paths --- src/asynchronous/fw_update.rs | 35 +++++++++++++----- src/asynchronous/internal/command.rs | 4 +- src/asynchronous/internal/mod.rs | 35 ++++++++++-------- src/asynchronous/interrupt.rs | 6 +-- src/command/mod.rs | 51 ++++++++++---------------- src/lib.rs | 18 ++++----- src/registers/autonegotiate_sink.rs | 17 ++++++++- src/registers/rx_caps.rs | 55 ++++++++++++++-------------- src/stream.rs | 2 + 9 files changed, 120 insertions(+), 103 deletions(-) diff --git a/src/asynchronous/fw_update.rs b/src/asynchronous/fw_update.rs index 4c8706d..4439764 100644 --- a/src/asynchronous/fw_update.rs +++ b/src/asynchronous/fw_update.rs @@ -233,13 +233,11 @@ impl BorrowedUpdaterInProgress { controllers: &mut [&mut T], data: &[u8], ) -> Result<(), Error> { - if controllers.is_empty() { - return Err(PdError::InvalidParams.into()); - } - trace!("Controllers: Sending burst write"); let update_args = self.update_args.ok_or(Error::Pd(PdError::InvalidParams))?; - if let Err(e) = controllers[0] + if let Err(e) = controllers + .get_mut(0) + .ok_or(PdError::InvalidParams)? .fw_update_burst_write(update_args.broadcast_u16_address as u8, data) .await { @@ -434,7 +432,10 @@ impl BorrowedUpdaterInProgress { read_result.read_data, ); - self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data); + self.args_buffer + .get_mut(current..current + read_len) + .ok_or(PdError::InvalidParams)? + .copy_from_slice(read_result.read_data); if read_result.is_complete() { // We have the full header metadata @@ -498,7 +499,10 @@ impl BorrowedUpdaterInProgress { self.block_args = None; let current = read_result.read_state.current; let read_len = read_result.read_data.len(); - self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data); + self.args_buffer + .get_mut(current..current + read_len) + .ok_or(PdError::InvalidParams)? + .copy_from_slice(read_result.read_data); if read_result.is_complete() { // We have the full header metadata @@ -573,7 +577,10 @@ impl BorrowedUpdaterInProgress { ) -> Result, Error> { let current = read_result.read_state.current; let read_len = read_result.read_data.len(); - self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data); + self.args_buffer + .get_mut(current..current + read_len) + .ok_or(PdError::InvalidParams)? + .copy_from_slice(read_result.read_data); self.fw_update_burst_write(controllers, read_result.read_data).await?; if read_result.is_complete() { // We have the full image size @@ -702,14 +709,22 @@ pub async fn perform_fw_update_borrowed( // Disable all interrupts while we're entering FW update mode // These go in the second half of the interrupt_guards array so they get dropped last - disable_all_interrupts(controllers, &mut interrupt_guards[half..]).await?; + disable_all_interrupts( + controllers, + interrupt_guards.get_mut(half..).ok_or(PdError::InvalidParams)?, + ) + .await?; info!("Starting update"); let result = updater.start_fw_update(controllers, delay).await; info!("Update started"); // Re-enable interrupts on port 0 only // These go in the first half of the interrupt_guards array so they get dropped first - enable_port0_interrupts(controllers, &mut interrupt_guards[0..half]).await?; + enable_port0_interrupts( + controllers, + interrupt_guards.get_mut(0..half).ok_or(PdError::InvalidParams)?, + ) + .await?; match result { Err(e) => { diff --git a/src/asynchronous/internal/command.rs b/src/asynchronous/internal/command.rs index fa785d5..921ff6a 100644 --- a/src/asynchronous/internal/command.rs +++ b/src/asynchronous/internal/command.rs @@ -96,14 +96,14 @@ impl Tps6699x { debug!("read_command_result: ret: {:?}", ret); // Overwrite return value if let Some(data) = data { - data.copy_from_slice(&buf[1..=data.len()]); + data.copy_from_slice(buf.get(1..=data.len()).ok_or(PdError::InvalidParams)?); } Ok(ret) } else { // No return value to check debug!("read_command_result: Done"); if let Some(data) = data { - data.copy_from_slice(&buf[..data.len()]); + data.copy_from_slice(buf.get(..data.len()).ok_or(PdError::InvalidParams)?); } Ok(ReturnValue::Success) } diff --git a/src/asynchronous/internal/mod.rs b/src/asynchronous/internal/mod.rs index c2fe08c..06c2882 100644 --- a/src/asynchronous/internal/mod.rs +++ b/src/asynchronous/internal/mod.rs @@ -2,10 +2,10 @@ use device_driver::AsyncRegisterInterface; use embedded_hal_async::i2c::I2c; use embedded_usb_pd::pdinfo::AltMode; -use embedded_usb_pd::pdo::{self, sink, source, ExpectedPdo}; +use embedded_usb_pd::pdo::{self, sink, source}; use embedded_usb_pd::{Error, LocalPortId, PdError}; -use crate::registers::rx_caps::EPR_PDO_START_INDEX; +use crate::registers::rx_caps::{RxCapsError, EPR_PDO_START_INDEX}; use crate::{ registers, warn, DeviceError, Mode, MAX_SUPPORTED_PORTS, PORT0, PORT1, TPS66993_NUM_PORTS, TPS66994_NUM_PORTS, }; @@ -45,10 +45,13 @@ impl device_driver::AsyncRegisterInterface for Port<'_, B> { buf[0] = address; buf[1] = data.len() as u8; - let _ = &buf[2..data.len() + 2].copy_from_slice(data); + let _ = &buf + .get_mut(2..data.len() + 2) + .ok_or(PdError::InvalidParams)? + .copy_from_slice(data); self.bus - .write(self.addr, &buf[..data.len() + 2]) + .write(self.addr, buf.get(..data.len() + 2).ok_or(PdError::InvalidParams)?) .await .map_err(Error::Bus) } @@ -69,7 +72,7 @@ impl device_driver::AsyncRegisterInterface for Port<'_, B> { } self.bus - .write_read(self.addr, ®, &mut buf[..full_len]) + .write_read(self.addr, ®, buf.get_mut(..full_len).ok_or(PdError::InvalidParams)?) .await .map_err(Error::Bus)?; @@ -88,7 +91,7 @@ impl device_driver::AsyncRegisterInterface for Port<'_, B> { // Controller is busy and can't respond PdError::Busy.into() } else { - data.copy_from_slice(&buf[1..data.len() + 1]); + data.copy_from_slice(buf.get(1..data.len() + 1).ok_or(PdError::InvalidParams)?); Ok(()) } } @@ -117,11 +120,7 @@ impl Tps6699x { /// Get the I2C address for a port fn port_addr(&self, port: LocalPortId) -> Result> { - if port.0 as usize >= self.num_ports { - PdError::InvalidPort.into() - } else { - Ok(self.addr[port.0 as usize]) - } + Ok(*self.addr.get(port.0 as usize).ok_or(PdError::InvalidPort)?) } /// Returns number of ports @@ -602,7 +601,7 @@ impl Tps6699x { register: u8, out_spr_pdos: &mut [T], out_epr_pdos: &mut [T], - ) -> Result<(usize, usize), DeviceError> { + ) -> Result<(usize, usize), DeviceError> { // Clamp to the maximum number of PDOs let num_pdos = if !out_epr_pdos.is_empty() { EPR_PDO_START_INDEX + out_epr_pdos.len() @@ -626,12 +625,16 @@ impl Tps6699x { let num_sprs = out_spr_pdos.len().min(rx_caps.num_valid_pdos() as usize); for (i, pdo) in out_spr_pdos.iter_mut().enumerate().take(num_sprs) { // SPR PDOs start at index 0 - *pdo = rx_caps[i]; + *pdo = *rx_caps + .get(i) + .ok_or(DeviceError::Error(Error::Pd(PdError::InvalidParams)))?; } let num_eprs = out_epr_pdos.len().min(rx_caps.num_valid_epr_pdos() as usize); for (i, pdo) in out_epr_pdos.iter_mut().enumerate().take(num_eprs) { - *pdo = rx_caps[EPR_PDO_START_INDEX + i]; + *pdo = *rx_caps + .get(EPR_PDO_START_INDEX + i) + .ok_or(DeviceError::Error(Error::Pd(PdError::InvalidParams)))?; } Ok((num_sprs, num_eprs)) @@ -645,7 +648,7 @@ impl Tps6699x { port: LocalPortId, out_spr_pdos: &mut [source::Pdo], out_epr_pdos: &mut [source::Pdo], - ) -> Result<(usize, usize), DeviceError> { + ) -> Result<(usize, usize), DeviceError> { self.get_rx_caps(port, registers::rx_caps::RX_SRC_ADDR, out_spr_pdos, out_epr_pdos) .await } @@ -658,7 +661,7 @@ impl Tps6699x { port: LocalPortId, out_spr_pdos: &mut [sink::Pdo], out_epr_pdos: &mut [sink::Pdo], - ) -> Result<(usize, usize), DeviceError> { + ) -> Result<(usize, usize), DeviceError> { self.get_rx_caps(port, registers::rx_caps::RX_SNK_ADDR, out_spr_pdos, out_epr_pdos) .await } diff --git a/src/asynchronous/interrupt.rs b/src/asynchronous/interrupt.rs index 935b253..e724536 100644 --- a/src/asynchronous/interrupt.rs +++ b/src/asynchronous/interrupt.rs @@ -27,12 +27,8 @@ pub trait InterruptController { port: LocalPortId, enabled: bool, ) -> Result> { - if port.0 as usize >= MAX_SUPPORTED_PORTS { - return PdError::InvalidPort.into(); - } - let mut state = self.interrupts_enabled().await?; - state[port.0 as usize] = enabled; + *state.get_mut(port.0 as usize).ok_or(PdError::InvalidPort)? = enabled; self.enable_interrupts_guarded(state).await } diff --git a/src/command/mod.rs b/src/command/mod.rs index 8f7155b..78d04f1 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -4,27 +4,16 @@ use bincode::error::{DecodeError, EncodeError}; use bincode::{Decode, Encode}; use embedded_usb_pd::PdError; +use crate::u32_from_str; + pub mod gcdm; pub mod muxr; pub mod trig; pub mod vdms; -/// Length of a command -const CMD_LEN: usize = 4; - /// TaskResult is only defined for lower 4 bits pub const CMD_4CC_TASK_RETURN_CODE_MASK: u8 = 0x0F; -/// Converts a 4-byte string into a u32 -const fn u32_from_str(value: &str) -> u32 { - if value.len() != CMD_LEN { - panic!("Invalid command string") - } - - let bytes = value.as_bytes(); - u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).to_le() -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[repr(u32)] @@ -32,27 +21,27 @@ pub enum Command { /// Previous command succeeded Success = 0, /// Invalid Command - Invalid = u32_from_str("!CMD"), + Invalid = u32_from_str(*b"!CMD"), /// Reset command - Gaid = u32_from_str("GAID"), + Gaid = u32_from_str(*b"GAID"), /// Tomcat firmware update mode enter - Tfus = u32_from_str("TFUs"), + Tfus = u32_from_str(*b"TFUs"), /// Tomcat firmware update mode init - Tfui = u32_from_str("TFUi"), + Tfui = u32_from_str(*b"TFUi"), /// Tomcat firmware update mode query - Tfuq = u32_from_str("TFUq"), + Tfuq = u32_from_str(*b"TFUq"), /// Tomcat firmware update mode exit - Tfue = u32_from_str("TFUe"), + Tfue = u32_from_str(*b"TFUe"), /// Tomcat firmware update data - Tfud = u32_from_str("TFUd"), + Tfud = u32_from_str(*b"TFUd"), /// Tomcat firmware update complete - Tfuc = u32_from_str("TFUc"), + Tfuc = u32_from_str(*b"TFUc"), /// System ready to sink - Srdy = u32_from_str("SRDY"), + Srdy = u32_from_str(*b"SRDY"), /// SRDY reset - Sryr = u32_from_str("SRYR"), + Sryr = u32_from_str(*b"SRYR"), /// Re-evaluate the Autonegotiate Sink register. /// @@ -61,10 +50,10 @@ pub enum Command { /// /// # Output /// [`ReturnValue`] - Aneg = u32_from_str("ANeg"), + Aneg = u32_from_str(*b"ANeg"), /// Trigger an Input GPIO event - Trig = u32_from_str("Trig"), + Trig = u32_from_str(*b"Trig"), /// Clear the dead battery flag. /// @@ -73,7 +62,7 @@ pub enum Command { /// /// # Output /// [`ReturnValue`] - Dbfg = u32_from_str("DBfg"), + Dbfg = u32_from_str(*b"DBfg"), /// Repeat transactions on I2C3m under certain conditions. /// @@ -82,7 +71,7 @@ pub enum Command { /// /// # Output /// [`ReturnValue`] - Muxr = u32_from_str("MuxR"), + Muxr = u32_from_str(*b"MuxR"), /// PD Data Reset /// @@ -91,7 +80,7 @@ pub enum Command { /// /// # Output /// [`ReturnValue`] - Drst = u32_from_str("DRST"), + Drst = u32_from_str(*b"DRST"), /// Send VDM. /// @@ -100,7 +89,7 @@ pub enum Command { /// /// # Output /// None - VDMs = u32_from_str("VDMs"), + VDMs = u32_from_str(*b"VDMs"), /// Execute a UCSI command /// @@ -109,7 +98,7 @@ pub enum Command { /// /// # Output /// [`embedded_usb_pd::ucsi::lpm::ResponseData`] - Ucsi = u32_from_str("UCSI"), + Ucsi = u32_from_str(*b"UCSI"), /// Get custom discovered modes /// @@ -118,7 +107,7 @@ pub enum Command { /// /// # Output /// [`gcdm::DiscoveredMode`] - GCdm = u32_from_str("GCdm"), + GCdm = u32_from_str(*b"GCdm"), } impl TryFrom for Command { diff --git a/src/lib.rs b/src/lib.rs index 4b2b3e1..b8aeec7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,15 +56,15 @@ impl From> for embedded_usb_pd::Error { #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Mode { /// Boot mode - Boot = u32_from_str("BOOT"), + Boot = u32_from_str(*b"BOOT"), /// Firmware corrupt on both banks - F211 = u32_from_str("F211"), + F211 = u32_from_str(*b"F211"), /// Before app config - App0 = u32_from_str("APP0"), + App0 = u32_from_str(*b"APP0"), /// After app config - App1 = u32_from_str("APP1"), + App1 = u32_from_str(*b"APP1"), /// App FW waiting for power - Wtpr = u32_from_str("WTPR"), + Wtpr = u32_from_str(*b"WTPR"), } impl PartialEq for Mode { @@ -102,12 +102,8 @@ impl Into<[u8; 4]> for Mode { const U32_STR_LEN: usize = 4; /// Converts a 4-byte string into a u32 -pub(crate) const fn u32_from_str(value: &str) -> u32 { - if value.len() != U32_STR_LEN { - panic!("Invalid command string") - } - let bytes = value.as_bytes(); - u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).to_le() +pub(crate) const fn u32_from_str(bytes: [u8; U32_STR_LEN]) -> u32 { + u32::from_le_bytes(bytes).to_le() } /// Common unit test functions diff --git a/src/registers/autonegotiate_sink.rs b/src/registers/autonegotiate_sink.rs index d25eb7b..f7a8751 100644 --- a/src/registers/autonegotiate_sink.rs +++ b/src/registers/autonegotiate_sink.rs @@ -162,7 +162,8 @@ impl From for PpsRequestInterval { 0x1 => PpsRequestInterval::FourSeconds, 0x2 => PpsRequestInterval::TwoSeconds, 0x3 => PpsRequestInterval::OneSecond, - _ => unreachable!("Masked value should always be in range 0-3"), + // Panic safety: All possible u8 values are unit tested + _ => unreachable!(), } } } @@ -678,4 +679,18 @@ mod tests { let bytes = actual.as_bytes(); assert_eq!(bytes, &AutonegotiateSink::DEFAULT); } + + #[test] + fn test_from_pps_request_interval() { + // Test simple values + assert_eq!(PpsRequestInterval::EightSeconds, 0x0.into()); + assert_eq!(PpsRequestInterval::FourSeconds, 0x1.into()); + assert_eq!(PpsRequestInterval::TwoSeconds, 0x2.into()); + assert_eq!(PpsRequestInterval::OneSecond, 0x3.into()); + + // Verify no panics + for v in u8::MIN..=u8::MAX { + let _ = PpsRequestInterval::from(v); + } + } } diff --git a/src/registers/rx_caps.rs b/src/registers/rx_caps.rs index 6125e4d..47d4ee8 100644 --- a/src/registers/rx_caps.rs +++ b/src/registers/rx_caps.rs @@ -1,5 +1,3 @@ -use core::ops::{Index, IndexMut}; - use bitfield::bitfield; use embedded_usb_pd::pdo::{sink, source, Common, ExpectedPdo, RoleCommon}; @@ -114,32 +112,30 @@ impl RxCaps { self.last_src_cap_is_epr = is_epr; self } -} -impl Index for RxCaps { - type Output = T; + /// Checked indexing into the PDOs + pub fn get(&self, index: usize) -> Option<&T> { + self.pdos.get(index) + } - fn index(&self, index: usize) -> &Self::Output { - if index < TOTAL_PDOS { - &self.pdos[index] - } else { - panic!("Index out of bounds: {}", index); - } + /// Checked mutable indexing into the PDOs + pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { + self.pdos.get_mut(index) } } -impl IndexMut for RxCaps { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - if index < TOTAL_PDOS { - &mut self.pdos[index] - } else { - panic!("Index out of bounds: {}", index); - } - } +/// Error type for functions that deal with received capabilities +#[derive(Copy, Clone, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RxCapsError { + /// PDO conversion error + ExpectedPdo(ExpectedPdo), + /// Invalid PDO index accessed, contains (requested, max) + InvalidPdoIndex(usize, usize), } impl TryFrom<[u8; LEN]> for RxCaps { - type Error = ExpectedPdo; + type Error = RxCapsError; fn try_from(raw: [u8; LEN]) -> Result { let raw = RxCapsRaw(raw); @@ -158,8 +154,9 @@ impl TryFrom<[u8; LEN]> for RxCaps { 4 => raw.pdo4(), 5 => raw.pdo5(), 6 => raw.pdo6(), - _ => unreachable!(), - })?; + _ => return Err(RxCapsError::InvalidPdoIndex(i, NUM_SPR_PDOS)), + }) + .map_err(RxCapsError::ExpectedPdo)?; } // Decode only valid EPR PDOs @@ -174,8 +171,9 @@ impl TryFrom<[u8; LEN]> for RxCaps { 1 => raw.epr_pdo1(), 2 => raw.epr_pdo2(), 3 => raw.epr_pdo3(), - _ => unreachable!(), - })?; + _ => return Err(RxCapsError::InvalidPdoIndex(i, NUM_EPR_PDOS)), + }) + .map_err(RxCapsError::ExpectedPdo)?; } Ok(RxCaps { @@ -220,8 +218,11 @@ mod test { let rx_src_caps = RxSrcCaps::try_from(buf).unwrap(); assert_eq!(rx_src_caps.num_valid_pdos(), 2); assert_eq!(rx_src_caps.num_valid_epr_pdos(), 1); - assert_eq!(rx_src_caps[0], TEST_SRC_PDO_FIXED_5V3A); - assert_eq!(rx_src_caps[1], TEST_SRC_PDO_FIXED_5V1A5); - assert_eq!(rx_src_caps[EPR_PDO_START_INDEX], TEST_SRC_EPR_PDO_FIXED_28V5A); + assert_eq!(*rx_src_caps.get(0).unwrap(), TEST_SRC_PDO_FIXED_5V3A); + assert_eq!(*rx_src_caps.get(1).unwrap(), TEST_SRC_PDO_FIXED_5V1A5); + assert_eq!( + *rx_src_caps.get(EPR_PDO_START_INDEX).unwrap(), + TEST_SRC_EPR_PDO_FIXED_28V5A + ); } } diff --git a/src/stream.rs b/src/stream.rs index 6dbb881..34b0f2d 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -83,6 +83,8 @@ impl SeekingStream { ); // Still waiting for a particular byte self.position += data.len(); + // Panic safety: this will never panic because we can always take a slice with a length of zero + #[allow(clippy::indexing_slicing)] &data[0..0] } } From 355ddc659e3a70030f2944faf3216ae9a3e639bf Mon Sep 17 00:00:00 2001 From: Robert Zieba Date: Mon, 15 Dec 2025 09:34:35 -0800 Subject: [PATCH 2/2] Add InvalidPdoIndex error struct --- src/registers/rx_caps.rs | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/registers/rx_caps.rs b/src/registers/rx_caps.rs index 47d4ee8..7439e6b 100644 --- a/src/registers/rx_caps.rs +++ b/src/registers/rx_caps.rs @@ -124,6 +124,14 @@ impl RxCaps { } } +/// Struct for [`RxCapsError::ExpectedPdo`] +#[derive(Copy, Clone, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct InvalidPdoIndex { + pub requested: usize, + pub max: usize, +} + /// Error type for functions that deal with received capabilities #[derive(Copy, Clone, Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] @@ -131,7 +139,7 @@ pub enum RxCapsError { /// PDO conversion error ExpectedPdo(ExpectedPdo), /// Invalid PDO index accessed, contains (requested, max) - InvalidPdoIndex(usize, usize), + InvalidPdoIndex(InvalidPdoIndex), } impl TryFrom<[u8; LEN]> for RxCaps { @@ -154,7 +162,12 @@ impl TryFrom<[u8; LEN]> for RxCaps { 4 => raw.pdo4(), 5 => raw.pdo5(), 6 => raw.pdo6(), - _ => return Err(RxCapsError::InvalidPdoIndex(i, NUM_SPR_PDOS)), + _ => { + return Err(RxCapsError::InvalidPdoIndex(InvalidPdoIndex { + requested: i, + max: NUM_SPR_PDOS, + })) + } }) .map_err(RxCapsError::ExpectedPdo)?; } @@ -171,7 +184,12 @@ impl TryFrom<[u8; LEN]> for RxCaps { 1 => raw.epr_pdo1(), 2 => raw.epr_pdo2(), 3 => raw.epr_pdo3(), - _ => return Err(RxCapsError::InvalidPdoIndex(i, NUM_EPR_PDOS)), + _ => { + return Err(RxCapsError::InvalidPdoIndex(InvalidPdoIndex { + requested: i, + max: NUM_EPR_PDOS, + })) + } }) .map_err(RxCapsError::ExpectedPdo)?; }