diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dc0b14..4e78d8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,21 @@ +# 0.4.0 (Nov 26th, 2025) + +## Breaking Changes + +* Published field streams now yield the current value on first poll, then subsequent changes. +* Published field streams' item type is now the raw field type (e.g., `State`) instead of + `*Changed` struct with `previous` and `new` fields. +* The `pub_setter` sub-attribute on `publish` has been removed. Use the new independent `setter` + attribute instead (e.g., `#[controller(publish, setter)]`). + +## New Features + +* New `getter` attribute for fields: generates a client-side getter method. Supports custom naming + via `#[controller(getter = "custom_name")]`. +* New `setter` attribute for fields: generates a client-side setter method independent of `publish`. + Supports custom naming via `#[controller(setter = "custom_name")]`. Can be combined with `publish` + to also broadcast changes. + # 0.3.0 (Nov 25th, 2025) * Macro now operates on a module. This allows the macro to have a visibility on both the struct and diff --git a/CLAUDE.md b/CLAUDE.md index a8bac12..057d54e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,8 +8,8 @@ This is a procedural macro crate that provides the `#[controller]` attribute mac * A controller struct that manages peripheral state. * Client API for sending commands to the controller. -* Signal mechanism for broadcasting events. -* Pub/sub system for state change notifications. +* Signal mechanism for broadcasting events (PubSubChannel). +* Watch-based subscriptions for state change notifications (yields current value first). The macro is applied to a module containing both the controller struct definition and its impl block, allowing coordinated code generation of the controller infrastructure, client API, and communication channels. @@ -52,18 +52,31 @@ The `expand_module()` function: * Combines the generated code back into the module structure along with any other items. Channel capacities and subscriber limits are also defined here: -* `ALL_CHANNEL_CAPACITY`: 8 -* `SIGNAL_CHANNEL_CAPACITY`: 8 -* `BROADCAST_MAX_PUBLISHERS`: 1 -* `BROADCAST_MAX_SUBSCRIBERS`: 16 +* `ALL_CHANNEL_CAPACITY`: 8 (method/getter/setter request channels) +* `SIGNAL_CHANNEL_CAPACITY`: 8 (signal PubSubChannel queue size) +* `BROADCAST_MAX_PUBLISHERS`: 1 (signals only) +* `BROADCAST_MAX_SUBSCRIBERS`: 16 (Watch for published fields, PubSubChannel for signals) ### Struct Processing (`src/controller/item_struct.rs`) -Processes the controller struct definition. For fields marked with `#[controller(publish)]`: -* Adds publisher fields to the struct. -* Generates setters (`set_`) that broadcast changes. -* Creates `` stream type and `Changed` event struct. +Processes the controller struct definition. Supports three field attributes: -The generated `new()` method initializes both user fields and generated publisher fields. +**`#[controller(publish)]`** - Enables state change subscriptions: +* Uses `embassy_sync::watch::Watch` channel (stores latest value). +* Generates internal setter (`set_`) that broadcasts changes. +* Creates `` subscriber stream type. +* Stream yields current value on first poll, then subsequent changes. + +**`#[controller(getter)]` or `#[controller(getter = "name")]`**: +* Generates a client-side getter method to read the field value. +* Default name is the field name; custom name can be specified. + +**`#[controller(setter)]` or `#[controller(setter = "name")]`**: +* Generates a client-side setter method to update the field value. +* Default name is `set_`; custom name can be specified. +* Can be combined with `publish` to also broadcast changes. + +The generated `new()` method initializes both user fields and generated sender fields, and sends +initial values to Watch channels so subscribers get them immediately. ### Impl Processing (`src/controller/item_impl.rs`) Processes the controller impl block. Distinguishes between: @@ -75,11 +88,18 @@ Processes the controller impl block. Distinguishes between: **Signal methods** (marked with `#[controller(signal)]`): * Methods have no body in the user's impl block. +* Uses `embassy_sync::pubsub::PubSubChannel` for broadcast. * Generates method implementation that broadcasts to subscribers. * Creates `` stream type and `Args` struct. * Signal methods are NOT exposed in the client API (controller emits them directly). -The generated `run()` method contains a `select_biased!` loop that receives method calls from clients and dispatches them to the user's implementations. +**Getter/setter methods** (from struct field attributes): +* Receives getter/setter field info from struct processing. +* Generates client-side getter methods that request current field value. +* Generates client-side setter methods that update field value (and broadcast if published). + +The generated `run()` method contains a `select_biased!` loop that receives method calls from +clients and dispatches them to the user's implementations. ### Utilities (`src/util.rs`) Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used for generating type and method names. @@ -97,5 +117,7 @@ Dev dependencies include `embassy-executor` and `embassy-time` for testing. * Singleton operation: multiple controller instances interfere with each other. * Methods must be async and cannot use reference parameters/return types. * Maximum 16 subscribers per state/signal stream. -* Published fields must implement `Clone` and `Debug`. -* Streams must be continuously polled or notifications are missed. +* Published fields must implement `Clone`. +* Published field streams yield current value on first poll; intermediate values may be missed if + not polled between changes. +* Signal streams must be continuously polled or notifications are missed. diff --git a/Cargo.lock b/Cargo.lock index 8e1e255..ada66a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -230,7 +230,7 @@ dependencies = [ [[package]] name = "firmware-controller" -version = "0.3.0" +version = "0.4.0" dependencies = [ "critical-section", "embassy-executor", diff --git a/Cargo.toml b/Cargo.toml index 0a75c33..c5861a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "firmware-controller" description = "Controller to decouple interactions between components in a no_std environment." -version = "0.3.0" +version = "0.4.0" edition = "2021" authors = [ "Zeeshan Ali Khan ", diff --git a/README.md b/README.md index 914896e..c457ba9 100644 --- a/README.md +++ b/README.md @@ -109,17 +109,21 @@ async fn client() { use embassy_time::{Timer, Duration}; let mut client = ControllerClient::new(); - let state_changed = client.receive_state_changed().unwrap().map(Either::Left); + let mut state_stream = client.receive_state_changed().unwrap(); let error_stream = client.receive_power_error().unwrap().map(Either::Right); + + // First poll returns the current (initial) state. + let initial_state = state_stream.next().await.unwrap(); + assert_eq!(initial_state, State::Disabled); + + // Now combine streams for event handling. + let state_changed = state_stream.map(Either::Left); let mut stream = select(state_changed, error_stream); client.enable_power().await.unwrap(); while let Some(event) = stream.next().await { match event { - Either::Left(ControllerStateChanged { - new: State::Enabled, - .. - }) => { + Either::Left(State::Enabled) => { // This is fine in this very simple example where we've only one client in a single // task. In a real-world application, you should ensure that the stream is polled // continuously. Otherwise, you might miss notifications. @@ -127,10 +131,7 @@ async fn client() { client.disable_power().await.unwrap(); } - Either::Left(ControllerStateChanged { - new: State::Disabled, - .. - }) => { + Either::Left(State::Disabled) => { Timer::after(Duration::from_secs(1)).await; client.enable_power().await.unwrap(); @@ -169,11 +170,16 @@ methods: controller and return the results. * For each `published` field: * `receive__changed()` method (e.g., `receive_state_changed()`) that returns a - stream of state changes. The stream yields `Changed` - structs (e.g., `ControllerStateChanged`) containing `previous` and `new` fields. - * If the field is marked with `#[controller(publish(pub_setter))]`, a public - `set_()` method (e.g., `set_state()`) is also generated on the client, allowing - external code to update the field value through the client API. + stream of state values. The first value yielded is the current state at subscription time, + and subsequent values are emitted when the field changes. The stream yields values of the + field type directly (e.g., `State`). +* For each field with a `getter` attribute (e.g., `#[controller(getter)]` or + `#[controller(getter = "custom_name")]`), a getter method is generated on the client. The default + name is the field name; a custom name can be specified. +* For each field with a `setter` attribute (e.g., `#[controller(setter)]` or + `#[controller(setter = "custom_name")]`), a public setter method is generated on the client, + allowing external code to update the field value through the client API. The default setter + name is `set_()`. This can be combined with `publish` to also broadcast changes. * For each `signal` method: * `receive_()` method (e.g., `receive_power_error()`) that returns a stream of signal events. The stream yields `Args` structs @@ -196,6 +202,8 @@ The `controller` macro assumes that you have the following dependencies in your * Methods must be async. * The maximum number of subscribers state change and signal streams is 16. We plan to provide an attribute to make this configurable in the future. -* The type of all published fields must implement `Clone` and `Debug`. -* The signal and published fields' streams must be continuely polled. Otherwise notifications will - be missed. +* The type of all published fields must implement `Clone`. +* Published field streams yield the current value on first poll, then subsequent changes. Only the + latest value is stored; intermediate values may be missed if the stream is not polled between + changes. +* Signal streams must be continuously polled. Otherwise notifications will be missed. diff --git a/src/controller/item_impl.rs b/src/controller/item_impl.rs index 0934d07..dbb95fa 100644 --- a/src/controller/item_impl.rs +++ b/src/controller/item_impl.rs @@ -7,11 +7,14 @@ use syn::{ Attribute, Ident, ImplItem, ImplItemFn, ItemImpl, Result, Signature, Token, Visibility, }; -use crate::{controller::item_struct::PublishedFieldInfo, util::snake_to_pascal_case}; +use crate::controller::item_struct::{GetterFieldInfo, PublishedFieldInfo, SetterFieldInfo}; +use crate::util::snake_to_pascal_case; pub(crate) fn expand( mut input: ItemImpl, published_fields: &[PublishedFieldInfo], + getter_fields: &[GetterFieldInfo], + setter_fields: &[SetterFieldInfo], ) -> Result { let struct_name = get_struct_name(&input)?; let struct_name_str = struct_name.to_string(); @@ -31,10 +34,9 @@ pub(crate) fn expand( let args_channels_rx_tx = methods.clone().map(|m| &m.args_channels_rx_tx); let select_arms = methods.clone().map(|m| &m.select_arm); - // Generate public setters for published fields with pub_setter. - let pub_setters: Vec<_> = published_fields + // Generate public setters for fields with setter attribute. + let pub_setters: Vec<_> = setter_fields .iter() - .filter(|field| field.pub_setter) .map(|field| generate_pub_setter(field, &struct_name)) .collect(); let pub_setter_channel_declarations = pub_setters.iter().map(|s| &s.channel_declarations); @@ -46,15 +48,31 @@ pub(crate) fn expand( let pub_setter_client_tx_rx_initializations = pub_setters.iter().map(|s| &s.client_tx_rx_initializations); + // Generate public getters for fields with getter attribute. + let pub_getters: Vec<_> = getter_fields + .iter() + .map(|field| generate_pub_getter(field, &struct_name)) + .collect(); + let pub_getter_channel_declarations = pub_getters.iter().map(|g| &g.channel_declarations); + let pub_getter_rx_tx = pub_getters.iter().map(|g| &g.rx_tx); + let pub_getter_select_arms = pub_getters.iter().map(|g| &g.select_arm); + let pub_getter_client_methods = pub_getters.iter().map(|g| &g.client_method); + let pub_getter_client_tx_rx_declarations = + pub_getters.iter().map(|g| &g.client_tx_rx_declarations); + let pub_getter_client_tx_rx_initializations = + pub_getters.iter().map(|g| &g.client_tx_rx_initializations); + let run_method = quote! { pub async fn run(mut self) { #(#args_channels_rx_tx)* #(#pub_setter_rx_tx)* + #(#pub_getter_rx_tx)* loop { futures::select_biased! { #(#select_arms,)* - #(#pub_setter_select_arms),* + #(#pub_setter_select_arms,)* + #(#pub_getter_select_arms,)* } } } @@ -97,12 +115,14 @@ pub(crate) fn expand( Ok(quote! { #(#args_channel_declarations)* #(#pub_setter_channel_declarations)* + #(#pub_getter_channel_declarations)* #input pub struct #client_name { #(#client_method_tx_rx_declarations)* #(#pub_setter_client_tx_rx_declarations)* + #(#pub_getter_client_tx_rx_declarations)* } impl #client_name { @@ -110,6 +130,7 @@ pub(crate) fn expand( Self { #(#client_method_tx_rx_initializations)* #(#pub_setter_client_tx_rx_initializations)* + #(#pub_getter_client_tx_rx_initializations)* } } @@ -117,6 +138,8 @@ pub(crate) fn expand( #(#pub_setter_client_methods)* + #(#pub_getter_client_methods)* + #(#published_field_getters)* #(#signal_getters)* @@ -600,7 +623,17 @@ struct PubSetter { client_tx_rx_initializations: TokenStream, } -fn generate_pub_setter(field: &PublishedFieldInfo, struct_name: &Ident) -> PubSetter { +#[derive(Debug)] +struct PubGetter { + channel_declarations: TokenStream, + rx_tx: TokenStream, + select_arm: TokenStream, + client_method: TokenStream, + client_tx_rx_declarations: TokenStream, + client_tx_rx_initializations: TokenStream, +} + +fn generate_pub_setter(field: &SetterFieldInfo, struct_name: &Ident) -> PubSetter { let field_name = &field.field_name; let field_type = &field.field_type; let setter_method_name = &field.setter_name; @@ -645,13 +678,27 @@ fn generate_pub_setter(field: &PublishedFieldInfo, struct_name: &Ident) -> PubSe let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); }; - let select_arm = quote! { - value = futures::FutureExt::fuse( - embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), - ) => { - self.#setter_method_name(value).await; + let select_arm = if let Some(internal_setter) = &field.internal_setter_name { + // Published field: call the internal setter which broadcasts changes. + quote! { + value = futures::FutureExt::fuse( + embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + ) => { + self.#internal_setter(value).await; + + embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; + } + } + } else { + // Non-published field: set the field directly. + quote! { + value = futures::FutureExt::fuse( + embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + ) => { + self.#field_name = value; - embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; + embassy_sync::channel::Sender::send(&#output_channel_tx_name, ()).await; + } } }; @@ -678,12 +725,12 @@ fn generate_pub_setter(field: &PublishedFieldInfo, struct_name: &Ident) -> PubSe embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, (), #capacity, - > + >, }; let client_tx_rx_initializations = quote! { #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), - #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name) + #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name), }; PubSetter { @@ -695,3 +742,107 @@ fn generate_pub_setter(field: &PublishedFieldInfo, struct_name: &Ident) -> PubSe client_tx_rx_initializations, } } + +fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGetter { + let field_name = &field.field_name; + let field_type = &field.field_type; + let getter_name = &field.getter_name; + let field_name_str = field_name.to_string(); + + let struct_name_caps = struct_name.to_string().to_uppercase(); + let field_name_caps = field_name_str.to_uppercase(); + let input_channel_name = Ident::new( + &format!("{}_GET_{}_INPUT_CHANNEL", struct_name_caps, field_name_caps), + field_name.span(), + ); + let output_channel_name = Ident::new( + &format!( + "{}_GET_{}_OUTPUT_CHANNEL", + struct_name_caps, field_name_caps + ), + field_name.span(), + ); + let capacity = super::ALL_CHANNEL_CAPACITY; + + let channel_declarations = quote! { + static #input_channel_name: + embassy_sync::channel::Channel< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + > = embassy_sync::channel::Channel::new(); + static #output_channel_name: + embassy_sync::channel::Channel< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + > = embassy_sync::channel::Channel::new(); + }; + + let input_channel_rx_name = Ident::new( + &format!("{}_get_request_rx", field_name_str), + field_name.span(), + ); + let output_channel_tx_name = Ident::new( + &format!("{}_get_response_tx", field_name_str), + field_name.span(), + ); + let rx_tx = quote! { + let #input_channel_rx_name = embassy_sync::channel::Channel::receiver(&#input_channel_name); + let #output_channel_tx_name = embassy_sync::channel::Channel::sender(&#output_channel_name); + }; + + let select_arm = quote! { + _ = futures::FutureExt::fuse( + embassy_sync::channel::Receiver::receive(&#input_channel_rx_name), + ) => { + let value = core::clone::Clone::clone(&self.#field_name); + + embassy_sync::channel::Sender::send(&#output_channel_tx_name, value).await; + } + }; + + let input_channel_tx_name = Ident::new( + &format!("{}_get_request_tx", field_name_str), + field_name.span(), + ); + let output_channel_rx_name = Ident::new( + &format!("{}_get_response_rx", field_name_str), + field_name.span(), + ); + let client_method = quote! { + pub async fn #getter_name(&self) -> #field_type { + embassy_sync::channel::Sender::send(&self.#input_channel_tx_name, ()).await; + embassy_sync::channel::Receiver::receive(&self.#output_channel_rx_name).await + } + }; + + let client_tx_rx_declarations = quote! { + #input_channel_tx_name: embassy_sync::channel::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + (), + #capacity, + >, + #output_channel_rx_name: embassy_sync::channel::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #field_type, + #capacity, + >, + }; + + let client_tx_rx_initializations = quote! { + #input_channel_tx_name: embassy_sync::channel::Channel::sender(&#input_channel_name), + #output_channel_rx_name: embassy_sync::channel::Channel::receiver(&#output_channel_name), + }; + + PubGetter { + channel_declarations, + rx_tx, + select_arm, + client_method, + client_tx_rx_declarations, + client_tx_rx_initializations, + } +} diff --git a/src/controller/item_struct.rs b/src/controller/item_struct.rs index da929ac..77ea666 100644 --- a/src/controller/item_struct.rs +++ b/src/controller/item_struct.rs @@ -1,98 +1,189 @@ use crate::util::*; use proc_macro2::TokenStream; use quote::quote; -use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, Result, Token}; +use syn::{spanned::Spanned, Field, Fields, Ident, ItemStruct, LitStr, Result, Token}; /// Information about a published field, to be used by impl processing. #[derive(Debug, Clone)] pub(crate) struct PublishedFieldInfo { + pub field_name: Ident, + pub subscriber_struct_name: Ident, +} + +/// Information about a field with a getter, to be used by impl processing. +#[derive(Debug, Clone)] +pub(crate) struct GetterFieldInfo { + pub field_name: Ident, + pub field_type: syn::Type, + pub getter_name: Ident, +} + +/// Information about a field with a public setter, to be used by impl processing. +#[derive(Debug, Clone)] +pub(crate) struct SetterFieldInfo { pub field_name: Ident, pub field_type: syn::Type, + /// The public setter method name (client API). pub setter_name: Ident, - pub subscriber_struct_name: Ident, - pub pub_setter: bool, + /// If the field is published, the internal setter name to call. Otherwise None. + pub internal_setter_name: Option, } /// Result of expanding a struct. pub(crate) struct ExpandedStruct { pub tokens: TokenStream, pub published_fields: Vec, + pub getter_fields: Vec, + pub setter_fields: Vec, } pub(crate) fn expand(mut input: ItemStruct) -> Result { let struct_name = &input.ident; - let fields = StructFields::parse(&mut input.fields, struct_name)?; - let field_names = fields.names(); + let struct_fields = StructFields::parse(&mut input.fields, struct_name)?; + let field_names = struct_fields.names().collect::>(); + + // Collect published field info. let ( - publish_channel_declarations, - publisher_fields_declarations, - publisher_fields_initializations, + watch_channel_declarations, + sender_fields_declarations, + sender_fields_initializations, setters, subscriber_declarations, published_fields_info, - ) = fields.published().fold( + ) = struct_fields.published().fold( (quote!(), quote!(), quote!(), quote!(), quote!(), Vec::new()), |( - publish_channels, - publisher_fields_declarations, - publisher_fields_initializations, + watch_channels, + sender_fields_declarations, + sender_fields_initializations, setters, subscribers, mut infos, ), f| { - let (publish_channel, publisher_field, publisher_field_init, setter, subscriber) = ( - &f.publish_channel_declaration, - &f.publisher_field_declaration, - &f.publisher_field_initialization, - &f.setter, - &f.subscriber_declaration, + let published = f.published.as_ref().unwrap(); + let (watch_channel, sender_field, sender_field_init, setter, subscriber) = ( + &published.watch_channel_declaration, + &published.sender_field_declaration, + &published.sender_field_initialization, + &published.setter, + &published.subscriber_declaration, ); - infos.push(f.info.clone()); + infos.push(published.info.clone()); ( - quote! { #publish_channels #publish_channel }, - quote! { #publisher_fields_declarations #publisher_field, }, - quote! { #publisher_fields_initializations #publisher_field_init, }, + quote! { #watch_channels #watch_channel }, + quote! { #sender_fields_declarations #sender_field, }, + quote! { #sender_fields_initializations #sender_field_init, }, quote! { #setters #setter }, quote! { #subscribers #subscriber }, infos, ) }, ); - let fields = fields.raw_fields().collect::>(); + + // Collect getter field info. + let getter_fields_info: Vec = struct_fields + .with_getter() + .map(|f| { + let field_name = f.field.ident.as_ref().unwrap().clone(); + let field_type = f.field.ty.clone(); + let getter_name = f.attrs.getter_name.clone().unwrap(); + GetterFieldInfo { + field_name, + field_type, + getter_name, + } + }) + .collect(); + + // Collect setter field info. + let setter_fields_info: Vec = struct_fields + .with_setter() + .map(|f| { + let field_name = f.field.ident.as_ref().unwrap().clone(); + let field_type = f.field.ty.clone(); + // Use explicit setter name if provided, otherwise default to set_. + let setter_name = + f.attrs.setter_name.clone().unwrap_or_else(|| { + Ident::new(&format!("set_{}", field_name), field_name.span()) + }); + // If published, use the internal setter; otherwise set field directly. + let internal_setter_name = if f.attrs.publish { + Some(Ident::new( + &format!("set_{}", field_name), + field_name.span(), + )) + } else { + None + }; + SetterFieldInfo { + field_name, + field_type, + setter_name, + internal_setter_name, + } + }) + .collect(); + + let fields = struct_fields.raw_fields().collect::>(); let vis = &input.vis; + // Generate initial value sends for Watch channels. + let initial_value_sends = published_fields_info.iter().map(|info| { + let field_name = &info.field_name; + let sender_name = Ident::new(&format!("{}_sender", field_name), field_name.span()); + quote! { + __self.#sender_name.send(core::clone::Clone::clone(&__self.#field_name)); + } + }); + Ok(ExpandedStruct { tokens: quote! { #vis struct #struct_name { #(#fields),*, - #publisher_fields_declarations + #sender_fields_declarations } impl #struct_name { #[allow(clippy::too_many_arguments)] pub fn new(#(#fields),*) -> Self { - Self { + let __self = Self { #(#field_names),*, - #publisher_fields_initializations - } + #sender_fields_initializations + }; + // Send initial values so subscribers can get them immediately. + #(#initial_value_sends)* + __self } #setters } - #publish_channel_declarations + #watch_channel_declarations #subscriber_declarations }, published_fields: published_fields_info, + getter_fields: getter_fields_info, + setter_fields: setter_fields_info, }) } -/// Parsed struct fields, retuned by `parse_struct_fields`. +/// Parsed controller attributes for a field. +#[derive(Debug, Default)] +struct ControllerAttrs { + /// Whether the field has `publish` attribute. + publish: bool, + /// If set, the getter method name (from `getter` or `getter = "name"`). + getter_name: Option, + /// If set, the setter method name (from `setter` or `setter = "name"`). + setter_name: Option, +} + +/// Parsed struct fields. #[derive(Debug)] struct StructFields { fields: Vec, @@ -120,240 +211,245 @@ impl StructFields { /// Names of all the fields. fn names(&self) -> impl Iterator { - // We know the fields are named by the time `self` is constructed. - self.fields - .iter() - .map(|f| f.field().ident.as_ref().unwrap()) + self.fields.iter().map(|f| f.field.ident.as_ref().unwrap()) } /// All raw fields. fn raw_fields(&self) -> impl Iterator { - self.fields.iter().map(StructField::field) + self.fields.iter().map(|f| &f.field) } /// All the published fields. - fn published(&self) -> impl Iterator { - self.fields.iter().filter_map(|field| match field { - StructField::Published(published) => Some(published.as_ref()), - _ => None, - }) + fn published(&self) -> impl Iterator { + self.fields.iter().filter(|f| f.published.is_some()) + } + + /// All fields with getters. + fn with_getter(&self) -> impl Iterator { + self.fields.iter().filter(|f| f.attrs.getter_name.is_some()) + } + + /// All fields with setters (via `setter` attribute). + fn with_setter(&self) -> impl Iterator { + self.fields.iter().filter(|f| f.attrs.setter_name.is_some()) } } -/// struct fields. +/// A struct field with its parsed controller attributes and generated code. #[derive(Debug)] -enum StructField { - /// Private field. - Private(Box), - /// Published field. - Published(Box), +struct StructField { + /// The field with controller attributes removed. + field: Field, + /// Parsed controller attributes. + attrs: ControllerAttrs, + /// Generated publish code (if `publish` attribute is present). + published: Option, } impl StructField { /// Parse a struct field. - fn parse(field: &mut Field, struct_name: &Ident) -> Result { - PublishedField::parse(field, struct_name).map(|published| { - published - .map(|p| StructField::Published(Box::new(p))) - .unwrap_or_else(|| StructField::Private(Box::new(field.clone()))) - }) - } + fn parse(field: &mut Field, struct_name: &Ident) -> Result { + let attrs = parse_controller_attrs(field)?; - /// Get the field. - fn field(&self) -> &Field { - match self { - Self::Private(field) => field.as_ref(), - Self::Published(published) => &published.field, - } + let published = if attrs.publish { + Some(generate_publish_code(field, struct_name)?) + } else { + None + }; + + Ok(Self { + field: field.clone(), + attrs, + published, + }) } } +/// Generated code for a published field. #[derive(Debug)] -/// Published field. -struct PublishedField { - /// Struct fields with the `controller` attributes removed. - field: Field, - /// Publisher field declaration. - publisher_field_declaration: proc_macro2::TokenStream, - /// Publisher field initialization. - publisher_field_initialization: proc_macro2::TokenStream, +struct PublishedFieldCode { + /// Watch sender field declaration. + sender_field_declaration: proc_macro2::TokenStream, + /// Watch sender field initialization. + sender_field_initialization: proc_macro2::TokenStream, /// Field setter. setter: proc_macro2::TokenStream, - /// Publish channel declaration. - publish_channel_declaration: proc_macro2::TokenStream, + /// Watch channel declaration. + watch_channel_declaration: proc_macro2::TokenStream, /// Subscriber struct declaration. subscriber_declaration: proc_macro2::TokenStream, /// Information to be passed to impl processing. info: PublishedFieldInfo, } -impl PublishedField { - /// Parse a struct field. - fn parse(field: &mut Field, struct_name: &Ident) -> Result> { - let attr = match field - .attrs - .iter() - .find(|attr| attr.path().is_ident("controller")) - { - Some(attr) => attr, - None => return Ok(None), - }; - let mut pub_setter = false; - attr.parse_nested_meta(|meta| { - if !meta.path.is_ident("publish") { - let e = format!( - "expected `publish` attribute, found `{}`", - meta.path.get_ident().unwrap() - ); - - return Err(syn::Error::new_spanned(attr, e)); +/// Parse the `#[controller(...)]` attributes from a field. +fn parse_controller_attrs(field: &mut Field) -> Result { + let mut attrs = ControllerAttrs::default(); + + let Some(attr) = field + .attrs + .iter() + .find(|attr| attr.path().is_ident("controller")) + else { + return Ok(attrs); + }; + + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("publish") { + attrs.publish = true; + } else if meta.path.is_ident("getter") { + let field_name = field.ident.as_ref().unwrap(); + if meta.input.peek(Token![=]) { + meta.input.parse::()?; + let name: LitStr = meta.input.parse()?; + attrs.getter_name = Some(Ident::new(&name.value(), name.span())); + } else { + attrs.getter_name = Some(field_name.clone()); + } + } else if meta.path.is_ident("setter") { + let field_name = field.ident.as_ref().unwrap(); + if meta.input.peek(Token![=]) { + meta.input.parse::()?; + let name: LitStr = meta.input.parse()?; + attrs.setter_name = Some(Ident::new(&name.value(), name.span())); + } else { + let default_name = format!("set_{}", field_name); + attrs.setter_name = Some(Ident::new(&default_name, field_name.span())); } + } else { + let ident = meta.path.get_ident().unwrap(); + let e = format!( + "expected `publish`, `getter`, or `setter`, found `{}`", + ident + ); + return Err(syn::Error::new_spanned(ident, e)); + } - if meta.input.peek(syn::token::Paren) { - let content; - syn::parenthesized!(content in meta.input); - while !content.is_empty() { - let nested_ident: Ident = content.parse()?; - if nested_ident == "pub_setter" { - pub_setter = true; - } else { - let e = - format!("expected `pub_setter` attribute, found `{}`", nested_ident); - return Err(syn::Error::new_spanned(&nested_ident, e)); - } + Ok(()) + })?; - if !content.is_empty() { - content.parse::()?; - } - } - } + // Remove controller attributes from the field. + field + .attrs + .retain(|attr| !attr.path().is_ident("controller")); - Ok(()) - })?; - field - .attrs - .retain(|attr| !attr.path().is_ident("controller")); - let struct_name = struct_name.to_string(); - let field_name = field.ident.as_ref().unwrap(); - let field_name_str = field_name.to_string(); - let ty = &field.ty; - - let struct_name_caps = pascal_to_snake_case(&struct_name.to_string()).to_ascii_uppercase(); - let field_name_caps = field_name_str.to_ascii_uppercase(); - let publish_channel_name = Ident::new( - &format!("{struct_name_caps}_{field_name_caps}_CHANNEL"), - field.span(), - ); - - let field_name_pascal = snake_to_pascal_case(&field_name_str); - let subscriber_struct_name = - Ident::new(&format!("{struct_name}{field_name_pascal}"), field.span()); - let change_struct_name = Ident::new( - &format!("{struct_name}{field_name_pascal}Changed"), - field.span(), - ); - let capacity = super::ALL_CHANNEL_CAPACITY; - let max_subscribers = super::BROADCAST_MAX_SUBSCRIBERS; - let max_publishers = super::BROADCAST_MAX_PUBLISHERS; - - let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); - let publisher_name = Ident::new(&format!("{field_name_str}_publisher"), field.span()); - let publisher_field_declaration = quote! { - #publisher_name: - embassy_sync::pubsub::Publisher< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - #change_struct_name, - #capacity, - #max_subscribers, - #max_publishers, - > - }; - let publisher_field_initialization = quote! { - // We only create one publisher so we can't fail. - #publisher_name: embassy_sync::pubsub::PubSubChannel::publisher(&#publish_channel_name).unwrap() - }; - let setter = quote! { - pub async fn #setter_name(&mut self, mut value: #ty) { - core::mem::swap(&mut self.#field_name, &mut value); - - let change = #change_struct_name { - previous: value, - new: core::clone::Clone::clone(&self.#field_name), - }; - embassy_sync::pubsub::publisher::Pub::publish_immediate( - &self.#publisher_name, - change, - ); - } - }; + Ok(attrs) +} - let publish_channel_declaration = quote! { - static #publish_channel_name: - embassy_sync::pubsub::PubSubChannel< - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - #change_struct_name, - #capacity, - #max_subscribers, - #max_publishers, - > = embassy_sync::pubsub::PubSubChannel::new(); - }; +/// Generate code for a published field using Watch channel. +fn generate_publish_code(field: &Field, struct_name: &Ident) -> Result { + let struct_name_str = struct_name.to_string(); + let field_name = field.ident.as_ref().unwrap(); + let field_name_str = field_name.to_string(); + let ty = &field.ty; + + let struct_name_caps = pascal_to_snake_case(&struct_name_str).to_ascii_uppercase(); + let field_name_caps = field_name_str.to_ascii_uppercase(); + let watch_channel_name = Ident::new( + &format!("{struct_name_caps}_{field_name_caps}_WATCH"), + field.span(), + ); - let subscriber_declaration = quote! { - pub struct #subscriber_struct_name { - subscriber: embassy_sync::pubsub::Subscriber< - 'static, - embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, - #change_struct_name, - #capacity, - #max_subscribers, - #max_publishers, - >, - } + let field_name_pascal = snake_to_pascal_case(&field_name_str); + let subscriber_struct_name = Ident::new( + &format!("{struct_name_str}{field_name_pascal}"), + field.span(), + ); + let max_subscribers = super::BROADCAST_MAX_SUBSCRIBERS; + + let setter_name = Ident::new(&format!("set_{field_name_str}"), field.span()); + let sender_name = Ident::new(&format!("{field_name_str}_sender"), field.span()); + + let sender_field_declaration = quote! { + #sender_name: + embassy_sync::watch::Sender< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #ty, + #max_subscribers, + > + }; + + let sender_field_initialization = quote! { + #sender_name: embassy_sync::watch::Watch::sender(&#watch_channel_name) + }; + + // Watch send() is sync, but we keep the setter async for API compatibility. + let setter = quote! { + pub async fn #setter_name(&mut self, value: #ty) { + self.#field_name = value; + self.#sender_name.send(core::clone::Clone::clone(&self.#field_name)); + } + }; + + let watch_channel_declaration = quote! { + static #watch_channel_name: + embassy_sync::watch::Watch< + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #ty, + #max_subscribers, + > = embassy_sync::watch::Watch::new(); + }; + + let subscriber_declaration = quote! { + pub struct #subscriber_struct_name { + receiver: embassy_sync::watch::Receiver< + 'static, + embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex, + #ty, + #max_subscribers, + >, + first_poll: bool, + } - impl #subscriber_struct_name { - pub fn new() -> Option { - embassy_sync::pubsub::PubSubChannel::subscriber(&#publish_channel_name) - .ok() - .map(|subscriber| Self { subscriber }) - } + impl #subscriber_struct_name { + pub fn new() -> Option { + embassy_sync::watch::Watch::receiver(&#watch_channel_name) + .map(|receiver| Self { + receiver, + first_poll: true, + }) } + } - impl futures::Stream for #subscriber_struct_name { - type Item = #change_struct_name; + impl futures::Stream for #subscriber_struct_name { + type Item = #ty; - fn poll_next( - self: core::pin::Pin<&mut Self>, - cx: &mut core::task::Context<'_>, - ) -> core::task::Poll> { - let subscriber = core::pin::Pin::new(&mut *self.get_mut().subscriber); - futures::Stream::poll_next(subscriber, cx) - } - } + fn poll_next( + mut self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + use core::future::Future; - #[derive(Debug, Clone)] - pub struct #change_struct_name { - pub previous: #ty, - pub new: #ty, - } - }; + let this = self.as_mut().get_mut(); - let info = PublishedFieldInfo { - field_name: field_name.clone(), - field_type: ty.clone(), - setter_name: setter_name.clone(), - subscriber_struct_name: subscriber_struct_name.clone(), - pub_setter, - }; + // First poll: return current value immediately if available. + if this.first_poll { + this.first_poll = false; + if let Some(value) = this.receiver.try_get() { + return core::task::Poll::Ready(Some(value)); + } + } - Ok(Some(PublishedField { - field: field.clone(), - publisher_field_declaration, - publisher_field_initialization, - setter, - publish_channel_declaration, - subscriber_declaration, - info, - })) - } + // Create changed() future and poll it in place. + let fut = this.receiver.changed(); + futures::pin_mut!(fut); + fut.poll(cx).map(Some) + } + } + }; + + let info = PublishedFieldInfo { + field_name: field_name.clone(), + subscriber_struct_name, + }; + + Ok(PublishedFieldCode { + sender_field_declaration, + sender_field_initialization, + setter, + watch_channel_declaration, + subscriber_declaration, + info, + }) } diff --git a/src/controller/mod.rs b/src/controller/mod.rs index 5a4583e..f2f076c 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -74,7 +74,12 @@ pub(crate) fn expand_module(input: ItemMod) -> Result { } let expanded_struct = item_struct::expand(struct_item)?; - let expanded_impl = item_impl::expand(impl_item, &expanded_struct.published_fields)?; + let expanded_impl = item_impl::expand( + impl_item, + &expanded_struct.published_fields, + &expanded_struct.getter_fields, + &expanded_struct.setter_fields, + )?; let struct_tokens = expanded_struct.tokens; Ok(quote! { diff --git a/tests/integration.rs b/tests/integration.rs index 5c5dea3..e90ec08 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -25,10 +25,11 @@ mod test_controller { use super::*; pub struct Controller { - #[controller(publish)] + #[controller(publish, getter = "get_current_state", setter = "change_state")] state: State, - #[controller(publish(pub_setter))] + #[controller(publish, getter, setter)] mode: Mode, + #[controller(setter)] counter: u32, } @@ -93,6 +94,13 @@ fn test_controller_basic_functionality() { // Test 1: Subscribe to state changes. let mut state_stream = client.receive_state_changed().expect("Failed to subscribe"); + // Test 1a: First poll returns the initial (current) value. + let initial_state = state_stream + .next() + .await + .expect("Should receive initial state"); + assert_eq!(initial_state, State::Idle, "Initial state should be Idle"); + // Test 2: Subscribe to signals. let mut error_stream = client .receive_error_occurred() @@ -119,21 +127,12 @@ fn test_controller_basic_functionality() { "Activate should succeed from Idle state" ); - // Verify we received the state change. - let state_change = state_stream + // Verify we received the state change (raw value, not Changed struct). + let new_state = state_stream .next() .await .expect("Should receive state change"); - assert_eq!( - state_change.previous, - State::Idle, - "Previous state should be Idle" - ); - assert_eq!( - state_change.new, - State::Active, - "New state should be Active" - ); + assert_eq!(new_state, State::Active, "New state should be Active"); // Verify we received the operation_complete signal. let _complete = complete_stream @@ -154,16 +153,11 @@ fn test_controller_basic_functionality() { ); // Verify state changed to Error. - let state_change = state_stream + let new_state = state_stream .next() .await .expect("Should receive state change"); - assert_eq!( - state_change.previous, - State::Active, - "Previous state should be Active" - ); - assert_eq!(state_change.new, State::Error, "New state should be Error"); + assert_eq!(new_state, State::Error, "New state should be Error"); // Verify we received the error signal. let error_signal = error_stream @@ -189,12 +183,34 @@ fn test_controller_basic_functionality() { "Should return InvalidState error" ); - // Test 8: Use pub_setter to change mode. + // Test 8: Use setter to change mode. client.set_mode(Mode::Debug).await; // Test 9: Call method with no return value. client.return_nothing().await; + // Test 10: Use getter with custom name to get state. + let state = client.get_current_state().await; + assert_eq!(state, State::Error, "State should be Error"); + + // Test 11: Use getter with default field name to get mode. + let mode = client.mode().await; + assert_eq!(mode, Mode::Debug, "Mode should be Debug"); + + // Test 12: Use setter with custom name (new syntax). + client.change_state(State::Idle).await; + let state = client.get_current_state().await; + assert_eq!( + state, + State::Idle, + "State should be Idle after change_state" + ); + + // Test 13: Use setter without publish (independent setter). + client.set_counter(100).await; + let counter = client.get_counter().await; + assert_eq!(counter, 100, "Counter should be 100 after set_counter"); + // If we get here, all tests passed. }); }