Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions src/asynchronous/fw_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,11 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
controllers: &mut [&mut T],
data: &[u8],
) -> Result<(), Error<T::BusError>> {
if controllers.is_empty() {
return Err(PdError::InvalidParams.into());
}

trace!("Controllers: Sending burst write");
let update_args = self.update_args.ok_or(Error::Pd(PdError::InvalidParams))?;
if let Err(e) = controllers[0]
if let Err(e) = controllers
.get_mut(0)
.ok_or(PdError::InvalidParams)?
.fw_update_burst_write(update_args.broadcast_u16_address as u8, data)
.await
{
Expand Down Expand Up @@ -434,7 +432,10 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
read_result.read_data,
);

self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data);
self.args_buffer
.get_mut(current..current + read_len)
.ok_or(PdError::InvalidParams)?
.copy_from_slice(read_result.read_data);

if read_result.is_complete() {
// We have the full header metadata
Expand Down Expand Up @@ -498,7 +499,10 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
self.block_args = None;
let current = read_result.read_state.current;
let read_len = read_result.read_data.len();
self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data);
self.args_buffer
.get_mut(current..current + read_len)
.ok_or(PdError::InvalidParams)?
.copy_from_slice(read_result.read_data);

if read_result.is_complete() {
// We have the full header metadata
Expand Down Expand Up @@ -573,7 +577,10 @@ impl<T: UpdateTarget> BorrowedUpdaterInProgress<T> {
) -> Result<Option<SeekOperation>, Error<T::BusError>> {
let current = read_result.read_state.current;
let read_len = read_result.read_data.len();
self.args_buffer[current..current + read_len].copy_from_slice(read_result.read_data);
self.args_buffer
.get_mut(current..current + read_len)
.ok_or(PdError::InvalidParams)?
.copy_from_slice(read_result.read_data);
self.fw_update_burst_write(controllers, read_result.read_data).await?;
if read_result.is_complete() {
// We have the full image size
Expand Down Expand Up @@ -702,14 +709,22 @@ pub async fn perform_fw_update_borrowed<T: UpdateTarget>(

// Disable all interrupts while we're entering FW update mode
// These go in the second half of the interrupt_guards array so they get dropped last
disable_all_interrupts(controllers, &mut interrupt_guards[half..]).await?;
disable_all_interrupts(
controllers,
interrupt_guards.get_mut(half..).ok_or(PdError::InvalidParams)?,
)
.await?;
info!("Starting update");
let result = updater.start_fw_update(controllers, delay).await;
info!("Update started");

// Re-enable interrupts on port 0 only
// These go in the first half of the interrupt_guards array so they get dropped first
enable_port0_interrupts(controllers, &mut interrupt_guards[0..half]).await?;
enable_port0_interrupts(
controllers,
interrupt_guards.get_mut(0..half).ok_or(PdError::InvalidParams)?,
)
.await?;

match result {
Err(e) => {
Expand Down
4 changes: 2 additions & 2 deletions src/asynchronous/internal/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ impl<B: I2c> Tps6699x<B> {
debug!("read_command_result: ret: {:?}", ret);
// Overwrite return value
if let Some(data) = data {
data.copy_from_slice(&buf[1..=data.len()]);
data.copy_from_slice(buf.get(1..=data.len()).ok_or(PdError::InvalidParams)?);
}
Ok(ret)
} else {
// No return value to check
debug!("read_command_result: Done");
if let Some(data) = data {
data.copy_from_slice(&buf[..data.len()]);
data.copy_from_slice(buf.get(..data.len()).ok_or(PdError::InvalidParams)?);
}
Ok(ReturnValue::Success)
}
Expand Down
35 changes: 19 additions & 16 deletions src/asynchronous/internal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
use device_driver::AsyncRegisterInterface;
use embedded_hal_async::i2c::I2c;
use embedded_usb_pd::pdinfo::AltMode;
use embedded_usb_pd::pdo::{self, sink, source, ExpectedPdo};
use embedded_usb_pd::pdo::{self, sink, source};
use embedded_usb_pd::{Error, LocalPortId, PdError};

use crate::registers::rx_caps::EPR_PDO_START_INDEX;
use crate::registers::rx_caps::{RxCapsError, EPR_PDO_START_INDEX};
use crate::{
registers, warn, DeviceError, Mode, MAX_SUPPORTED_PORTS, PORT0, PORT1, TPS66993_NUM_PORTS, TPS66994_NUM_PORTS,
};
Expand Down Expand Up @@ -45,10 +45,13 @@ impl<B: I2c> device_driver::AsyncRegisterInterface for Port<'_, B> {

buf[0] = address;
buf[1] = data.len() as u8;
let _ = &buf[2..data.len() + 2].copy_from_slice(data);
let _ = &buf
.get_mut(2..data.len() + 2)
.ok_or(PdError::InvalidParams)?
.copy_from_slice(data);

self.bus
.write(self.addr, &buf[..data.len() + 2])
.write(self.addr, buf.get(..data.len() + 2).ok_or(PdError::InvalidParams)?)
.await
.map_err(Error::Bus)
}
Expand All @@ -69,7 +72,7 @@ impl<B: I2c> device_driver::AsyncRegisterInterface for Port<'_, B> {
}

self.bus
.write_read(self.addr, &reg, &mut buf[..full_len])
.write_read(self.addr, &reg, buf.get_mut(..full_len).ok_or(PdError::InvalidParams)?)
.await
.map_err(Error::Bus)?;

Expand All @@ -88,7 +91,7 @@ impl<B: I2c> device_driver::AsyncRegisterInterface for Port<'_, B> {
// Controller is busy and can't respond
PdError::Busy.into()
} else {
data.copy_from_slice(&buf[1..data.len() + 1]);
data.copy_from_slice(buf.get(1..data.len() + 1).ok_or(PdError::InvalidParams)?);
Ok(())
}
}
Expand Down Expand Up @@ -117,11 +120,7 @@ impl<B: I2c> Tps6699x<B> {

/// Get the I2C address for a port
fn port_addr(&self, port: LocalPortId) -> Result<u8, Error<B::Error>> {
if port.0 as usize >= self.num_ports {
PdError::InvalidPort.into()
} else {
Ok(self.addr[port.0 as usize])
}
Ok(*self.addr.get(port.0 as usize).ok_or(PdError::InvalidPort)?)
}

/// Returns number of ports
Expand Down Expand Up @@ -602,7 +601,7 @@ impl<B: I2c> Tps6699x<B> {
register: u8,
out_spr_pdos: &mut [T],
out_epr_pdos: &mut [T],
) -> Result<(usize, usize), DeviceError<B::Error, ExpectedPdo>> {
) -> Result<(usize, usize), DeviceError<B::Error, RxCapsError>> {
// Clamp to the maximum number of PDOs
let num_pdos = if !out_epr_pdos.is_empty() {
EPR_PDO_START_INDEX + out_epr_pdos.len()
Expand All @@ -626,12 +625,16 @@ impl<B: I2c> Tps6699x<B> {
let num_sprs = out_spr_pdos.len().min(rx_caps.num_valid_pdos() as usize);
for (i, pdo) in out_spr_pdos.iter_mut().enumerate().take(num_sprs) {
// SPR PDOs start at index 0
*pdo = rx_caps[i];
*pdo = *rx_caps
.get(i)
.ok_or(DeviceError::Error(Error::Pd(PdError::InvalidParams)))?;
}

let num_eprs = out_epr_pdos.len().min(rx_caps.num_valid_epr_pdos() as usize);
for (i, pdo) in out_epr_pdos.iter_mut().enumerate().take(num_eprs) {
*pdo = rx_caps[EPR_PDO_START_INDEX + i];
*pdo = *rx_caps
.get(EPR_PDO_START_INDEX + i)
.ok_or(DeviceError::Error(Error::Pd(PdError::InvalidParams)))?;
}

Ok((num_sprs, num_eprs))
Expand All @@ -645,7 +648,7 @@ impl<B: I2c> Tps6699x<B> {
port: LocalPortId,
out_spr_pdos: &mut [source::Pdo],
out_epr_pdos: &mut [source::Pdo],
) -> Result<(usize, usize), DeviceError<B::Error, ExpectedPdo>> {
) -> Result<(usize, usize), DeviceError<B::Error, RxCapsError>> {
self.get_rx_caps(port, registers::rx_caps::RX_SRC_ADDR, out_spr_pdos, out_epr_pdos)
.await
}
Expand All @@ -658,7 +661,7 @@ impl<B: I2c> Tps6699x<B> {
port: LocalPortId,
out_spr_pdos: &mut [sink::Pdo],
out_epr_pdos: &mut [sink::Pdo],
) -> Result<(usize, usize), DeviceError<B::Error, ExpectedPdo>> {
) -> Result<(usize, usize), DeviceError<B::Error, RxCapsError>> {
self.get_rx_caps(port, registers::rx_caps::RX_SNK_ADDR, out_spr_pdos, out_epr_pdos)
.await
}
Expand Down
6 changes: 1 addition & 5 deletions src/asynchronous/interrupt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@ pub trait InterruptController {
port: LocalPortId,
enabled: bool,
) -> Result<Self::Guard, Error<Self::BusError>> {
if port.0 as usize >= MAX_SUPPORTED_PORTS {
return PdError::InvalidPort.into();
}

let mut state = self.interrupts_enabled().await?;
state[port.0 as usize] = enabled;
*state.get_mut(port.0 as usize).ok_or(PdError::InvalidPort)? = enabled;
self.enable_interrupts_guarded(state).await
}

Expand Down
51 changes: 20 additions & 31 deletions src/command/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,44 @@ use bincode::error::{DecodeError, EncodeError};
use bincode::{Decode, Encode};
use embedded_usb_pd::PdError;

use crate::u32_from_str;

pub mod gcdm;
pub mod muxr;
pub mod trig;
pub mod vdms;

/// Length of a command
const CMD_LEN: usize = 4;

/// TaskResult is only defined for lower 4 bits
pub const CMD_4CC_TASK_RETURN_CODE_MASK: u8 = 0x0F;

/// Converts a 4-byte string into a u32
const fn u32_from_str(value: &str) -> u32 {
if value.len() != CMD_LEN {
panic!("Invalid command string")
}

let bytes = value.as_bytes();
u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).to_le()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u32)]
pub enum Command {
/// Previous command succeeded
Success = 0,
/// Invalid Command
Invalid = u32_from_str("!CMD"),
Invalid = u32_from_str(*b"!CMD"),
/// Reset command
Gaid = u32_from_str("GAID"),
Gaid = u32_from_str(*b"GAID"),

/// Tomcat firmware update mode enter
Tfus = u32_from_str("TFUs"),
Tfus = u32_from_str(*b"TFUs"),
/// Tomcat firmware update mode init
Tfui = u32_from_str("TFUi"),
Tfui = u32_from_str(*b"TFUi"),
/// Tomcat firmware update mode query
Tfuq = u32_from_str("TFUq"),
Tfuq = u32_from_str(*b"TFUq"),
/// Tomcat firmware update mode exit
Tfue = u32_from_str("TFUe"),
Tfue = u32_from_str(*b"TFUe"),
/// Tomcat firmware update data
Tfud = u32_from_str("TFUd"),
Tfud = u32_from_str(*b"TFUd"),
/// Tomcat firmware update complete
Tfuc = u32_from_str("TFUc"),
Tfuc = u32_from_str(*b"TFUc"),

/// System ready to sink
Srdy = u32_from_str("SRDY"),
Srdy = u32_from_str(*b"SRDY"),
/// SRDY reset
Sryr = u32_from_str("SRYR"),
Sryr = u32_from_str(*b"SRYR"),

/// Re-evaluate the Autonegotiate Sink register.
///
Expand All @@ -61,10 +50,10 @@ pub enum Command {
///
/// # Output
/// [`ReturnValue`]
Aneg = u32_from_str("ANeg"),
Aneg = u32_from_str(*b"ANeg"),

/// Trigger an Input GPIO event
Trig = u32_from_str("Trig"),
Trig = u32_from_str(*b"Trig"),

/// Clear the dead battery flag.
///
Expand All @@ -73,7 +62,7 @@ pub enum Command {
///
/// # Output
/// [`ReturnValue`]
Dbfg = u32_from_str("DBfg"),
Dbfg = u32_from_str(*b"DBfg"),

/// Repeat transactions on I2C3m under certain conditions.
///
Expand All @@ -82,7 +71,7 @@ pub enum Command {
///
/// # Output
/// [`ReturnValue`]
Muxr = u32_from_str("MuxR"),
Muxr = u32_from_str(*b"MuxR"),

/// PD Data Reset
///
Expand All @@ -91,7 +80,7 @@ pub enum Command {
///
/// # Output
/// [`ReturnValue`]
Drst = u32_from_str("DRST"),
Drst = u32_from_str(*b"DRST"),

/// Send VDM.
///
Expand All @@ -100,7 +89,7 @@ pub enum Command {
///
/// # Output
/// None
VDMs = u32_from_str("VDMs"),
VDMs = u32_from_str(*b"VDMs"),

/// Execute a UCSI command
///
Expand All @@ -109,7 +98,7 @@ pub enum Command {
///
/// # Output
/// [`embedded_usb_pd::ucsi::lpm::ResponseData`]
Ucsi = u32_from_str("UCSI"),
Ucsi = u32_from_str(*b"UCSI"),

/// Get custom discovered modes
///
Expand All @@ -118,7 +107,7 @@ pub enum Command {
///
/// # Output
/// [`gcdm::DiscoveredMode`]
GCdm = u32_from_str("GCdm"),
GCdm = u32_from_str(*b"GCdm"),
}

impl TryFrom<u32> for Command {
Expand Down
18 changes: 7 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ impl<BE, T> From<DeviceError<BE, T>> for embedded_usb_pd::Error<BE> {
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Mode {
/// Boot mode
Boot = u32_from_str("BOOT"),
Boot = u32_from_str(*b"BOOT"),
/// Firmware corrupt on both banks
F211 = u32_from_str("F211"),
F211 = u32_from_str(*b"F211"),
/// Before app config
App0 = u32_from_str("APP0"),
App0 = u32_from_str(*b"APP0"),
/// After app config
App1 = u32_from_str("APP1"),
App1 = u32_from_str(*b"APP1"),
/// App FW waiting for power
Wtpr = u32_from_str("WTPR"),
Wtpr = u32_from_str(*b"WTPR"),
}

impl PartialEq<u32> for Mode {
Expand Down Expand Up @@ -102,12 +102,8 @@ impl Into<[u8; 4]> for Mode {

const U32_STR_LEN: usize = 4;
/// Converts a 4-byte string into a u32
pub(crate) const fn u32_from_str(value: &str) -> u32 {
if value.len() != U32_STR_LEN {
panic!("Invalid command string")
}
let bytes = value.as_bytes();
u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]).to_le()
pub(crate) const fn u32_from_str(bytes: [u8; U32_STR_LEN]) -> u32 {
u32::from_le_bytes(bytes).to_le()
}

/// Common unit test functions
Expand Down
Loading