diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7c1ae9c8e..17ccb296b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -192,6 +192,7 @@ if(RAPIDSMPF_HAVE_STREAMING) src/streaming/coll/shuffler.cpp src/streaming/core/channel.cpp src/streaming/core/context.cpp + src/streaming/core/fanout.cpp src/streaming/core/leaf_node.cpp src/streaming/core/node.cpp src/streaming/core/spillable_messages.cpp diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 2530d596a..cf1734965 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -46,6 +46,8 @@ class Channel { * @param msg The msg to send. * @return A coroutine that evaluates to true if the msg was successfully sent or * false if the channel was shut down. + * + * @throws std::logic_error If the message is empty. */ coro::task send(Message msg); @@ -56,6 +58,8 @@ class Channel { * * @return A coroutine that evaluates to the message, which will be empty if the * channel is shut down. + * + * @throws std::logic_error If the received message is empty. */ coro::task receive(); @@ -85,6 +89,13 @@ class Channel { */ [[nodiscard]] bool empty() const noexcept; + /** + * @brief Check whether the channel is shut down. + * + * @return True if the channel is shut down. + */ + [[nodiscard]] bool is_shutdown() const noexcept; + private: Channel(std::shared_ptr spillable_messages) : sm_{std::move(spillable_messages)} {} diff --git a/cpp/include/rapidsmpf/streaming/core/context.hpp b/cpp/include/rapidsmpf/streaming/core/context.hpp index cbb01483e..8c4fb914a 100644 --- a/cpp/include/rapidsmpf/streaming/core/context.hpp +++ b/cpp/include/rapidsmpf/streaming/core/context.hpp @@ -78,6 +78,13 @@ class Context { */ [[nodiscard]] std::shared_ptr comm() const noexcept; + /** + * @brief Returns the logger. + * + * @return Reference to the logger. + */ + [[nodiscard]] Communicator::Logger& logger() const noexcept; + /** * @brief Returns the progress thread. * diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp new file mode 100644 index 000000000..c15193272 --- /dev/null +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -0,0 +1,67 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include + +namespace rapidsmpf::streaming::node { + +/** + * @brief Fanout policy controlling how messages are propagated. + */ +enum class FanoutPolicy : uint8_t { + /** + * @brief Process messages as they arrive and immediately forward them. + * + * Messages are forwarded as soon as they are received from the input channel. + * The next message is not processed until all output channels have completed + * sending the current one, ensuring backpressure and synchronized flow. + */ + BOUNDED, + + /** + * @brief Forward messages without enforcing backpressure. + * + * In this mode, messages may be accumulated internally before being + * broadcast, or they may be forwarded immediately depending on the + * implementation and downstream consumption rate. + * + * This mode disables coordinated backpressure between outputs, allowing + * consumers to process at independent rates, but can lead to unbounded + * buffering and increased memory usage. + */ + UNBOUNDED, +}; + +/** + * @brief Broadcast messages from one input channel to multiple output channels. 
+ * + * The node continuously receives messages from the input channel and forwards + * them to all output channels according to the selected fanout policy, see + * ::FanoutPolicy. + * + * Each output channel receives a deep copy of the same message. + * + * @param ctx The node context to use. + * @param ch_in Input channel from which messages are received. + * @param chs_out Output channels to which messages are broadcast. Must be at least 2. + * @param policy The fanout strategy to use (see ::FanoutPolicy). + * + * @return Streaming node representing the fanout operation. + * + * @throws std::invalid_argument If an unknown fanout policy is specified or if the number + * of output channels is less than 2. + */ +Node fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out, + FanoutPolicy policy +); + +} // namespace rapidsmpf::streaming::node diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index 79f31eff4..347d41790 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -9,14 +9,15 @@ namespace rapidsmpf::streaming { coro::task Channel::send(Message msg) { + RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); auto result = co_await rb_.produce(sm_->insert(std::move(msg))); co_return result == coro::ring_buffer_result::produce::produced; } coro::task Channel::receive() { - auto msg = co_await rb_.consume(); - if (msg.has_value()) { - co_return sm_->extract(*msg); + auto msg_id = co_await rb_.consume(); + if (msg_id.has_value()) { + co_return sm_->extract(*msg_id); } else { co_return Message{}; } @@ -34,4 +35,8 @@ bool Channel::empty() const noexcept { return rb_.empty(); } +bool Channel::is_shutdown() const noexcept { + return rb_.is_shutdown(); +} + } // namespace rapidsmpf::streaming diff --git a/cpp/src/streaming/core/context.cpp b/cpp/src/streaming/core/context.cpp index 97a399242..1714ad011 100644 --- a/cpp/src/streaming/core/context.cpp +++ b/cpp/src/streaming/core/context.cpp @@ -126,6 +126,10 @@ std::shared_ptr Context::comm() const noexcept { return comm_; } +Communicator::Logger& Context::logger() const noexcept { + return comm_->logger(); +} + std::shared_ptr Context::progress_thread() const noexcept { return progress_thread_; } diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp new file mode 100644 index 000000000..1b9631daa --- /dev/null +++ b/cpp/src/streaming/core/fanout.cpp @@ -0,0 +1,450 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace rapidsmpf::streaming::node { +namespace { + +/** + * @brief Try to allocate memory from the memory types that the message content uses. + * + * @param msg The message to allocate memory for. + * @return The memory types to try to allocate from. + */ +constexpr std::span try_memory_types(Message const& msg) { + auto const& cd = msg.content_description(); + // if the message content uses device memory, try to allocate from device memory + // first, else allocate from host memory + return cd.content_size(MemoryType::DEVICE) > 0 + ? MEMORY_TYPES + : std::span{ + MEMORY_TYPES.begin() + static_cast(MemoryType::HOST), + MEMORY_TYPES.end() + }; +} + +/** + * @brief Asynchronously send a message to multiple output channels. + * + * @param msg The message to broadcast. 
Each channel receives a deep copy of the original + * message. + * @param chs_out The set of output channels to which the message is sent. + */ +Node send_to_channels( + Context& ctx, Message&& msg, std::vector>& chs_out +) { + RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); + + auto async_copy_and_send = [](Context& ctx_, + Message const& msg_, + size_t msg_sz_, + Channel& ch_) -> coro::task { + co_await ctx_.executor()->schedule(); + auto res = ctx_.br()->reserve_or_fail(msg_sz_, try_memory_types(msg_)); + co_return co_await ch_.send(msg_.copy(res)); + }; + + // async copy & send tasks for all channels except the last one + std::vector> async_send_tasks; + async_send_tasks.reserve(chs_out.size() - 1); + size_t msg_sz = msg.copy_cost(); + for (size_t i = 0; i < chs_out.size() - 1; i++) { + async_send_tasks.emplace_back(async_copy_and_send(ctx, msg, msg_sz, *chs_out[i])); + } + + // note that the send tasks may return false if the channel is shut down. But we can + // safely ignore this in bounded fanout. + coro_results(co_await coro::when_all(std::move(async_send_tasks))); + + // move the message to the last channel to avoid extra copy + co_await chs_out.back()->send(std::move(msg)); +} + +/** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * @note Bounded fanout requires all the output channels to consume messages before + * the next message is sent/consumed from the input channel. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param chs_out The output channels to send messages to. + * @return A node representing the bounded fanout operation. + */ +Node bounded_fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out +) { + ShutdownAtExit c1{ch_in}; + ShutdownAtExit c2{chs_out}; + co_await ctx->executor()->schedule(); + while (true) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + + // filter out shut down channels to avoid making unnecessary copies + std::erase_if(chs_out, [](auto&& ch) { return ch->is_shutdown(); }); + if (chs_out.empty()) { + // all channels are shut down, so we can break & shutdown the input channel + break; + } + co_await send_to_channels(*ctx, std::move(msg), chs_out); + } + + std::vector drain_tasks; + drain_tasks.reserve(chs_out.size()); + for (auto& ch : chs_out) { + drain_tasks.emplace_back(ch->drain(ctx->executor())); + } + coro_results(co_await coro::when_all(std::move(drain_tasks))); +} + +/** + * @brief Unbounded fanout implementation. + * + * The implementation follows a pull-based model, where the send tasks request data from + * the recv task. There is one recv task that receives messages from the input channel, + * and there are N send tasks that send messages to the output channels. + * + * Main task operation: + * - There is a shared deque of cached messages, and a vector that indicates the next + * index of the message to be sent to each output channel. + * - All shared resources are protected by a mutex. There are two condition variables + * where: + * - recv task notifies send tasks when new messages are cached + * - send tasks notify recv task when they have completed sending messages + * - Recv task awaits until the number of cached messages at least one send task has + * completed sending all the cached messages. It will then pull a message from the input + * channel, cache it, and notify the send tasks about the new messages. 
recv task + * continues this process until the input channel is fully consumed. + * - Each send task awaits until there are more cached messages to send. When the new + * messages available noitification is received, it will continue to copy and send cached + * messages, starting from the index of the last sent message, to the end of the cached + * messages (as it last observed). Then it updates the last completed message index and + * notifies the recv task. This process continues until the recv task notifies that the + * input channel is fully consumed. + * + * Additional considerations: + * - In the recv task loop, it also identifies the lowest completed message index by all + * send tasks. Message upto this index are no longer needed, and are released from the + * cached messages deque. + * - When a send task fails to send a message, this means the channel may have been + * prematurely shut down. In this case, it sets a sential value to mark it as invalid. + * Recv task will filter out channels with the invalid sentinel value. + * - There are two RAII helpers to ensure that the notification mechanisms are properly + * cleaned up when the unbounded fanout state goes out of scope/ encounters an error. + * + */ +struct UnboundedFanout { + /** + * @brief Constructor. + * + * @param num_channels The number of output channels. + */ + explicit UnboundedFanout(size_t num_channels) : per_ch_processed(num_channels, 0) {} + + /** + * @brief Sentinel value indicating that the index is invalid. This is set when a + * failure occurs during send tasks. process input task will filter out messages with + * this index. + */ + static constexpr size_t InvalidIdx = std::numeric_limits::max(); + + /** + * @brief RAII helper class to set a channel index to invalid and notify the process + * input task to check if it should break. + */ + struct SetChannelIdxInvalidAtExit { + UnboundedFanout* fanout; + size_t& self_next_idx; + + ~SetChannelIdxInvalidAtExit() { + coro::sync_wait(set_channel_idx_invalid()); + } + + Node set_channel_idx_invalid() { + if (self_next_idx != InvalidIdx) { + { + auto lock = co_await fanout->mtx.scoped_lock(); + self_next_idx = InvalidIdx; + } + co_await fanout->request_data.notify_one(); + } + } + }; + + /** + * @brief Send messages to multiple output channels. + * + * @param ctx The context to use. + * @param self_next_idx Next index to send for the current channel + * @param ch_out The output channel to send messages to. + * @return A coroutine representing the task. 
+ */ + Node send_task(Context& ctx, size_t& self_next_idx, std::shared_ptr ch_out) { + ShutdownAtExit ch_shutdown{ch_out}; + SetChannelIdxInvalidAtExit set_ch_idx_invalid{ + .fanout = this, .self_next_idx = self_next_idx + }; + co_await ctx.executor()->schedule(); + + size_t n_available_messages = 0; + std::vector> messages_to_send; + while (true) { + { + auto lock = co_await mtx.scoped_lock(); + co_await data_ready.wait(lock, [&] { + // irrespective of no_more_input, update the end_idx to the total + // number of messages + n_available_messages = recv_messages.size(); + return no_more_input || self_next_idx < n_available_messages; + }); + if (no_more_input && self_next_idx == n_available_messages) { + // no more messages will be received, and all messages have been sent + break; + } + // stash msg references under the lock + messages_to_send.reserve(n_available_messages - self_next_idx); + for (size_t i = self_next_idx; i < n_available_messages; i++) { + messages_to_send.emplace_back(recv_messages[i]); + } + } + + for (auto const& msg : messages_to_send) { + RAPIDSMPF_EXPECTS(!msg.get().empty(), "message cannot be empty"); + + auto res = ctx.br()->reserve_or_fail( + msg.get().copy_cost(), try_memory_types(msg.get()) + ); + if (!co_await ch_out->send(msg.get().copy(res))) { + // Failed to send message. Could be that the channel is shut down. + // So we need to abort the send task, and notify the process input + // task + co_await set_ch_idx_invalid.set_channel_idx_invalid(); + co_return; + } + } + messages_to_send.clear(); + + // now next_idx can be updated to end_idx, and if !no_more_input, we need to + // request the recv task for more data + auto lock = co_await mtx.scoped_lock(); + self_next_idx = n_available_messages; + if (self_next_idx == recv_messages.size()) { + if (no_more_input) { + // no more messages will be received, and all messages have been sent + break; + } else { + // request more data from the input channel + lock.unlock(); + co_await request_data.notify_one(); + } + } + } + co_await ch_out->drain(ctx.executor()); + } + + /** + * @brief RAII helper class to set no_more_input to true and notify all send tasks to + * wind down when the unbounded fanout state goes out of scope. + */ + struct SetInputDoneAtExit { + UnboundedFanout* fanout; + + ~SetInputDoneAtExit() { + coro::sync_wait(set_input_done()); + } + + // forcibly set no_more_input to true and notify all send tasks to wind down + Node set_input_done() { + { + auto lock = co_await fanout->mtx.scoped_lock(); + fanout->no_more_input = true; + } + co_await fanout->data_ready.notify_all(); + } + }; + + /** + * @brief Wait for a data request from the send tasks. + * + * @return A minmax pair of `per_ch_processed` values. min is index of the last + * completed message index + 1 and max is the index of the latest processed message + * index + 1. If both are InvalidIdx, it means that all send tasks are in an invalid + * state. 
+ */ + auto wait_for_data_request() -> coro::task> { + size_t per_ch_processed_min = InvalidIdx; + size_t per_ch_processed_max = InvalidIdx; + + auto lock = co_await mtx.scoped_lock(); + co_await request_data.wait(lock, [&] { + auto filtered_view = std::ranges::filter_view( + per_ch_processed, [](size_t idx) { return idx != InvalidIdx; } + ); + + auto it = std::ranges::begin(filtered_view); // advance to first valid idx + auto end = std::ranges::end(filtered_view); + if (it == end) { + // no valid indices, so all send tasks are in an invalid state + return true; + } + + auto [min_it, max_it] = std::minmax_element(it, end); + per_ch_processed_min = *min_it; + per_ch_processed_max = *max_it; + + return per_ch_processed_max == recv_messages.size(); + }); + + co_return std::make_pair(per_ch_processed_min, per_ch_processed_max); + } + + /** + * @brief Process input messages and notify send tasks to copy & send messages. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @return A coroutine representing the task. + */ + Node recv_task(Context& ctx, std::shared_ptr ch_in) { + ShutdownAtExit ch_in_shutdown{ch_in}; + SetInputDoneAtExit set_input_done{.fanout = this}; + co_await ctx.executor()->schedule(); + + // index of the first message to purge + size_t purge_idx = 0; + + // no_more_input is only set by this task, so reading without lock is safe here + while (!no_more_input) { + auto [per_ch_processed_min, per_ch_processed_max] = + co_await wait_for_data_request(); + if (per_ch_processed_min == InvalidIdx && per_ch_processed_max == InvalidIdx) + { + break; + } + + // receive a message from the input channel + auto msg = co_await ch_in->receive(); + + { + auto lock = co_await mtx.scoped_lock(); + if (msg.empty()) { + no_more_input = true; + } else { + recv_messages.emplace_back(std::move(msg)); + } + } + + // notify send_tasks to copy & send messages + co_await data_ready.notify_all(); + + // Reset messages that are no longer needed, so that they release the memory. + // However the deque is not resized. This guarantees that the indices are not + // invalidated. + while (purge_idx < per_ch_processed_min) { + recv_messages[purge_idx].reset(); + purge_idx++; + } + } + + co_await ch_in->drain(ctx.executor()); + } + + coro::mutex mtx; + + /// @brief recv task notifies send tasks to copy & send messages + coro::condition_variable data_ready; + + /// @brief send tasks notify recv task to pull more data from the input channel + coro::condition_variable request_data; + + /// @brief set to true when the input channel is fully consumed + bool no_more_input{false}; + + /// @brief messages received from the input channel. Using a deque to avoid + /// invalidating references by reallocations. + std::deque recv_messages; + + /// @brief number of messages processed for each channel (ie. next index to send for + /// each channel) + std::vector per_ch_processed; +}; + +/** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * In contrast to `bounded_fanout`, an unbounded fanout supports arbitrary + * consumption orders of the output channels. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param chs_out The output channels to send messages to. + * @return A node representing the unbounded fanout operation. 
+ */ +Node unbounded_fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out +) { + auto& executor = *ctx->executor(); + + ShutdownAtExit ch_in_shutdown{ch_in}; + ShutdownAtExit chs_out_shutdown{chs_out}; + co_await ctx->executor()->schedule(); + UnboundedFanout fanout(chs_out.size()); + + std::vector tasks; + tasks.reserve(chs_out.size() + 1); + + for (size_t i = 0; i < chs_out.size(); i++) { + tasks.emplace_back(executor.schedule( + fanout.send_task(*ctx, fanout.per_ch_processed[i], std::move(chs_out[i])) + )); + } + tasks.emplace_back(executor.schedule(fanout.recv_task(*ctx, std::move(ch_in)))); + + coro_results(co_await coro::when_all(std::move(tasks))); +} + +} // namespace + +Node fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out, + FanoutPolicy policy +) { + RAPIDSMPF_EXPECTS( + chs_out.size() > 1, + "fanout requires at least 2 output channels", + std::invalid_argument + ); + + switch (policy) { + case FanoutPolicy::BOUNDED: + return bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); + case FanoutPolicy::UNBOUNDED: + return unbounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); + default: + RAPIDSMPF_FAIL("Unknown broadcast policy", std::invalid_argument); + } +} + +} // namespace rapidsmpf::streaming::node diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 35041f8d6..1be272ceb 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -73,6 +73,7 @@ if(RAPIDSMPF_HAVE_STREAMING) test_sources INTERFACE streaming/test_allgather.cpp streaming/test_error_handling.cpp + streaming/test_fanout.cpp streaming/test_leaf_node.cpp streaming/test_lineariser.cpp streaming/test_message.cpp diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp new file mode 100644 index 000000000..74a2420be --- /dev/null +++ b/cpp/tests/streaming/test_fanout.cpp @@ -0,0 +1,449 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include "base_streaming_fixture.hpp" + +#include + +using namespace rapidsmpf; +using namespace rapidsmpf::streaming; +namespace node = rapidsmpf::streaming::node; +using rapidsmpf::streaming::node::FanoutPolicy; + +/** + * @brief Helper to make a sequence of Message with values [0, n). 
+ */ +std::vector make_int_inputs(int n) { + std::vector inputs; + inputs.reserve(n); + + Message::CopyCallback copy_cb = [](Message const& msg, MemoryReservation&) { + return Message{ + msg.sequence_number(), + std::make_unique(msg.get()), + ContentDescription{}, + msg.copy_cb() + }; + }; + + for (int i = 0; i < n; ++i) { + inputs.emplace_back(i, std::make_unique(i), ContentDescription{}, copy_cb); + } + return inputs; +} + +std::string policy_to_string(FanoutPolicy policy) { + switch (policy) { + case FanoutPolicy::BOUNDED: + return "bounded"; + case FanoutPolicy::UNBOUNDED: + return "unbounded"; + default: + return "unknown"; + } +} + +using BaseStreamingFanout = BaseStreamingFixture; + +TEST_F(BaseStreamingFanout, InvalidNumberOfOutputChannels) { + auto in = ctx->create_channel(); + std::vector> out_chs; + out_chs.push_back(ctx->create_channel()); + EXPECT_THROW( + std::ignore = node::fanout(ctx, in, out_chs, FanoutPolicy::BOUNDED), + std::invalid_argument + ); +} + +class StreamingFanout + : public BaseStreamingFixture, + public ::testing::WithParamInterface> { + public: + void SetUp() override { + std::tie(policy, num_threads, num_out_chs, num_msgs) = GetParam(); + SetUpWithThreads(num_threads); + + // restrict fanout tests to single communicator mode to reduce test runtime + if (GlobalEnvironment->type() != TestEnvironmentType::SINGLE) { + GTEST_SKIP() << "Skipping test in non-single communicator mode"; + } + } + + FanoutPolicy policy; + int num_threads; + int num_out_chs; + int num_msgs; +}; + +INSTANTIATE_TEST_SUITE_P( + StreamingFanout, + StreamingFanout, + ::testing::Combine( + ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), + ::testing::Values(1, 4), // number of threads + ::testing::Values(2, 4), // number of output channels + ::testing::Values(10, 100) // number of messages + ), + [](testing::TestParamInfo const& info) { + return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" + + std::to_string(std::get<1>(info.param)) + "_nch_out_" + + std::to_string(std::get<2>(info.param)) + "_nmsgs_" + + std::to_string(std::get<3>(info.param)); + } +); + +TEST_P(StreamingFanout, SinkPerChannel) { + auto inputs = make_int_inputs(num_msgs); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], outs[i])); + } + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + // Validate sizes + EXPECT_EQ(outs[c].size(), static_cast(num_msgs)); + + // Validate ordering/content and that shallow copies share the same underlying + // object + for (int i = 0; i < num_msgs; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } + } +} + +namespace { + +/** + * @brief A node that pulls and shuts down a channel after a certain number of messages + * have been received. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param out_messages The output messages to store the received messages in. + * @param max_messages The maximum number of messages to receive. + * @return A coroutine representing the task. 
+ */ +Node shutdown_channel_after_n_messages( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector& out_messages, + size_t max_messages +) { + ShutdownAtExit c{ch_in}; + co_await ctx->executor()->schedule(); + + for (size_t i = 0; i < max_messages; ++i) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + out_messages.push_back(std::move(msg)); + } + co_await ch_in->shutdown(); +} + +Node throwing_node(std::shared_ptr ctx, std::shared_ptr ch_out) { + ShutdownAtExit c{ch_out}; + co_await ctx->executor()->schedule(); + throw std::logic_error("throwing source"); +} + +} // namespace + +// all channels shutsdown after receiving num_msgs / 2 messages +TEST_P(StreamingFanout, SinkPerChannel_ShutdownHalfWay) { + auto inputs = make_int_inputs(num_msgs); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back( + shutdown_channel_after_n_messages(ctx, out_chs[i], outs[i], num_msgs / 2) + ); + } + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + EXPECT_EQ(static_cast(num_msgs / 2), outs[c].size()); + + for (int i = 0; i < num_msgs / 2; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } + } +} + +// only odd channels shutdown after receiving num_msgs / 2 messages, others continue to +// receive all messages +TEST_P(StreamingFanout, SinkPerChannel_OddChannelsShutdownHalfWay) { + auto inputs = make_int_inputs(num_msgs); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + if (i % 2 == 0) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], outs[i])); + } else { + nodes.emplace_back(shutdown_channel_after_n_messages( + ctx, out_chs[i], outs[i], num_msgs / 2 + )); + } + } + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + int expected_size = c % 2 == 0 ? 
num_msgs : num_msgs / 2; + EXPECT_EQ(outs[c].size(), expected_size); + + for (int i = 0; i < expected_size; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } + } +} + +class ThrowingStreamingFanout : public StreamingFanout {}; + +INSTANTIATE_TEST_SUITE_P( + ThrowingStreamingFanout, + ThrowingStreamingFanout, + ::testing::Combine( + ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), + ::testing::Values(1, 4), // number of threads + ::testing::Values(4), // number of output channels + ::testing::Values(10) // number of messages + ), + [](testing::TestParamInfo const& info) { + return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" + + std::to_string(std::get<1>(info.param)) + "_nch_out_" + + std::to_string(std::get<2>(info.param)) + "_nmsgs_" + + std::to_string(std::get<3>(info.param)); + } +); + +// tests that throwing a source node propagates the error to the pipeline. This test will +// throw, but it should not hang. +TEST_P(ThrowingStreamingFanout, ThrowingSource) { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(throwing_node(ctx, in)); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + std::vector dummy_out; + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], dummy_out)); + } + + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::logic_error); +} + +// tests that throwing a sink node propagates the error to the pipeline. This test +// will throw, but it should not hang. +TEST_P(ThrowingStreamingFanout, ThrowingSink) { + auto inputs = make_int_inputs(num_msgs); + + std::vector nodes; + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + std::vector> dummy_outs(num_out_chs); + for (int i = 0; i < num_out_chs; ++i) { + if (i == 0) { + nodes.emplace_back(throwing_node(ctx, out_chs[i])); + } else { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], dummy_outs[i])); + } + } + + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::logic_error); +} + +namespace { +enum class ConsumePolicy : uint8_t { + CHANNEL_ORDER, // consume all messages from a single channel before moving to the + // next + MESSAGE_ORDER, // consume messages from all channels before moving to the next + // message +}; + +Node many_input_sink( + std::shared_ptr ctx, + std::vector> chs, + ConsumePolicy consume_policy, + std::vector>& outs +) { + ShutdownAtExit c{chs}; + co_await ctx->executor()->schedule(); + + if (consume_policy == ConsumePolicy::CHANNEL_ORDER) { + for (size_t i = 0; i < chs.size(); ++i) { + while (true) { + auto msg = co_await chs[i]->receive(); + if (msg.empty()) { + break; + } + outs[i].push_back(std::move(msg)); + } + } + } else if (consume_policy == ConsumePolicy::MESSAGE_ORDER) { + std::unordered_set active_chs{}; + for (size_t i = 0; i < chs.size(); ++i) { + active_chs.insert(i); + } + while (!active_chs.empty()) { + for (auto it = active_chs.begin(); it != active_chs.end();) { + auto msg = co_await chs[*it]->receive(); + if (msg.empty()) { + it = active_chs.erase(it); + } else { + 
outs[*it].emplace_back(std::move(msg)); + it++; + } + } + } + } +} +} // namespace + +struct ManyInputSinkStreamingFanout : public StreamingFanout { + void run(ConsumePolicy consume_policy) { + auto inputs = make_int_inputs(num_msgs); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.push_back(node::fanout(ctx, in, out_chs, policy)); + + nodes.push_back(many_input_sink(ctx, out_chs, consume_policy, outs)); + + run_streaming_pipeline(std::move(nodes)); + } + + std::vector expected(num_msgs); + std::iota(expected.begin(), expected.end(), 0); + for (int c = 0; c < num_out_chs; ++c) { + SCOPED_TRACE("channel " + std::to_string(c)); + std::vector actual; + actual.reserve(outs[c].size()); + std::ranges::transform( + outs[c], std::back_inserter(actual), [](const Message& m) { + return m.get(); + } + ); + EXPECT_EQ(expected, actual); + } + } +}; + +INSTANTIATE_TEST_SUITE_P( + ManyInputSinkStreamingFanout, + ManyInputSinkStreamingFanout, + ::testing::Combine( + ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), + ::testing::Values(1, 4), // number of threads + ::testing::Values(2, 4), // number of output channels + ::testing::Values(10, 100) // number of messages + ), + [](testing::TestParamInfo const& info) { + return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" + + std::to_string(std::get<1>(info.param)) + "_nch_out_" + + std::to_string(std::get<2>(info.param)) + "_nmsgs_" + + std::to_string(std::get<3>(info.param)); + } +); + +TEST_P(ManyInputSinkStreamingFanout, ChannelOrder) { + if (policy == FanoutPolicy::BOUNDED) { + GTEST_SKIP() << "Bounded fanout does not support channel order"; + } + + EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::CHANNEL_ORDER)); +} + +TEST_P(ManyInputSinkStreamingFanout, MessageOrder) { + EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::MESSAGE_ORDER)); +} diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt b/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt index 435da5669..b36b377e5 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt +++ b/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt @@ -5,8 +5,8 @@ # cmake-format: on # ================================================================================= -set(cython_modules channel.pyx context.pyx leaf_node.pyx message.pyx node.pyx - spillable_messages.pyx utilities.pyx +set(cython_modules channel.pyx context.pyx fanout.pyx leaf_node.pyx message.pyx + spillable_messages.pyx node.pyx utilities.pyx ) rapids_cython_create_modules( diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd new file mode 100644 index 000000000..813d17617 --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from libc.stdint cimport uint8_t +from libcpp.memory cimport shared_ptr +from libcpp.vector cimport vector + +from rapidsmpf.streaming.core.channel cimport cpp_Channel +from rapidsmpf.streaming.core.context cimport cpp_Context +from rapidsmpf.streaming.core.node cimport cpp_Node + + +cdef extern from "" \ + namespace "rapidsmpf::streaming::node" nogil: + cpdef enum class FanoutPolicy (uint8_t): + BOUNDED + UNBOUNDED + + cdef cpp_Node cpp_fanout \ + "rapidsmpf::streaming::node::fanout"( + shared_ptr[cpp_Context] ctx, + shared_ptr[cpp_Channel] ch_in, + vector[shared_ptr[cpp_Channel]] chs_out, + FanoutPolicy policy + ) except + diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi new file mode 100644 index 000000000..8f8864620 --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import IntEnum + +from rapidsmpf.streaming.core.channel import Channel +from rapidsmpf.streaming.core.context import Context +from rapidsmpf.streaming.core.node import CppNode + +class FanoutPolicy(IntEnum): + BOUNDED = ... + UNBOUNDED = ... + +def fanout( + ctx: Context, + ch_in: Channel, + chs_out: list[Channel], + policy: FanoutPolicy, +) -> CppNode: ... diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx new file mode 100644 index 000000000..a5edaef5b --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from libcpp.memory cimport make_unique +from libcpp.utility cimport move +from libcpp.vector cimport vector + +from rapidsmpf.streaming.core.channel cimport Channel +from rapidsmpf.streaming.core.context cimport Context +from rapidsmpf.streaming.core.fanout cimport FanoutPolicy +from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node + + +def fanout(Context ctx, Channel ch_in, chs_out, FanoutPolicy policy): + """ + Broadcast messages from one input channel to multiple output channels. + + The node continuously receives messages from the input channel and forwards + them to all output channels according to the selected fanout policy. + + Each output channel receives a shallow copy of the same message; no payload + data is duplicated. All copies share the same underlying payload, ensuring + zero-copy broadcast semantics. + + Parameters + ---------- + ctx + The node context to use. + ch_in + Input channel from which messages are received. + chs_out + Output channels to which messages are broadcast. + policy + The fanout policy to use. `FanoutPolicy.BOUNDED` can be used if all + output channels are being consumed by independent consumers in the + downstream. `FanoutPolicy.UNBOUNDED` can be used if the output channels + are being consumed by a single/ shared consumer in the downstream. + Returns + ------- + Streaming node representing the fanout operation. + + Raises + ------ + ValueError + If an unknown fanout policy is specified. + + Notes + ----- + Since messages are shallow-copied, releasing a payload (``release()``) + is only valid on messages that hold exclusive ownership of the payload. 
+ + Examples + -------- + >>> import rapidsmpf.streaming.core as streaming + >>> ctx = streaming.Context(...) + >>> ch_in = ctx.create_channel() + >>> ch_out1 = ctx.create_channel() + >>> ch_out2 = ctx.create_channel() + >>> node = streaming.fanout( + ... ctx, ch_in, [ch_out1, ch_out2], streaming.FanoutPolicy.BOUNDED + ... ) + """ + cdef vector[shared_ptr[cpp_Channel]] _chs_out + if len(chs_out) == 0: + raise ValueError("output channels cannot be empty") + owner = [] + for ch_out in chs_out: + if not isinstance(ch_out, Channel): + raise TypeError("All elements in chs_out must be Channel instances") + owner.append(ch_out) + _chs_out.push_back((ch_out)._handle) + + cdef cpp_Node _ret + with nogil: + _ret = cpp_fanout( + ctx._handle, ch_in._handle, move(_chs_out), policy + ) + return CppNode.from_handle(make_unique[cpp_Node](move(_ret)), owner) diff --git a/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py new file mode 100644 index 000000000..8fe54bb9c --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for streaming fanout node.""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING + +import pytest + +import cudf + +from rapidsmpf.streaming.core.fanout import FanoutPolicy, fanout +from rapidsmpf.streaming.core.leaf_node import pull_from_channel, push_to_channel +from rapidsmpf.streaming.core.message import Message +from rapidsmpf.streaming.core.node import run_streaming_pipeline +from rapidsmpf.streaming.cudf.table_chunk import TableChunk +from rapidsmpf.testing import assert_eq +from rapidsmpf.utils.cudf import cudf_to_pylibcudf_table + +if TYPE_CHECKING: + from rmm.pylibrmm.stream import Stream + + from rapidsmpf.streaming.core.channel import Channel + from rapidsmpf.streaming.core.context import Context + + +@pytest.mark.parametrize("policy", [FanoutPolicy.BOUNDED, FanoutPolicy.UNBOUNDED]) +def test_fanout_basic(context: Context, stream: Stream, policy: FanoutPolicy) -> None: + """Test basic fanout functionality with multiple output channels.""" + # Create channels + ch_in: Channel[TableChunk] = context.create_channel() + ch_out1: Channel[TableChunk] = context.create_channel() + ch_out2: Channel[TableChunk] = context.create_channel() + + # Create test messages + messages = [] + for i in range(5): + df = cudf.DataFrame( + {"a": [i, i + 1, i + 2], "b": [i * 10, i * 10 + 1, i * 10 + 2]} + ) + chunk = TableChunk.from_pylibcudf_table( + cudf_to_pylibcudf_table(df), stream, exclusive_view=False + ) + messages.append(Message(i, chunk)) + + # Create nodes + push_node = push_to_channel(context, ch_in, messages) + fanout_node = fanout(context, ch_in, [ch_out1, ch_out2], policy) + pull_node1, output1 = pull_from_channel(context, ch_out1) + pull_node2, output2 = pull_from_channel(context, ch_out2) + + # Run pipeline + with ThreadPoolExecutor(max_workers=1) as executor: + run_streaming_pipeline( + nodes=[push_node, fanout_node, pull_node1, pull_node2], + py_executor=executor, + ) + + # Verify results + results1 = output1.release() + results2 = output2.release() + + assert len(results1) == 5, f"Expected 5 messages in output1, got {len(results1)}" + assert len(results2) == 5, f"Expected 5 messages in output2, got {len(results2)}" + + # Check that both outputs received the same sequence 
numbers and data + for i in range(5): + assert results1[i].sequence_number == i + assert results2[i].sequence_number == i + + chunk1 = TableChunk.from_message(results1[i]) + chunk2 = TableChunk.from_message(results2[i]) + + # Expected data + expected_df = cudf.DataFrame( + {"a": [i, i + 1, i + 2], "b": [i * 10, i * 10 + 1, i * 10 + 2]} + ) + expected_table = cudf_to_pylibcudf_table(expected_df) + + # Verify data is correct + assert_eq(chunk1.table_view(), expected_table) + assert_eq(chunk2.table_view(), expected_table) + + +@pytest.mark.parametrize("num_outputs", [1, 3, 5]) +@pytest.mark.parametrize("policy", [FanoutPolicy.BOUNDED, FanoutPolicy.UNBOUNDED]) +def test_fanout_multiple_outputs( + context: Context, stream: Stream, num_outputs: int, policy: FanoutPolicy +) -> None: + """Test fanout with varying numbers of output channels.""" + # Create channels + ch_in: Channel[TableChunk] = context.create_channel() + chs_out: list[Channel[TableChunk]] = [ + context.create_channel() for _ in range(num_outputs) + ] + + if num_outputs == 1: + with pytest.raises(ValueError): + fanout(context, ch_in, chs_out, policy) + return + + # Create test messages + messages = [] + for i in range(3): + df = cudf.DataFrame({"x": [i * 10, i * 10 + 1]}) + chunk = TableChunk.from_pylibcudf_table( + cudf_to_pylibcudf_table(df), stream, exclusive_view=False + ) + messages.append(Message(i, chunk)) + + # Create nodes + push_node = push_to_channel(context, ch_in, messages) + fanout_node = fanout(context, ch_in, chs_out, policy) + pull_nodes = [] + outputs = [] + for ch_out in chs_out: + pull_node, output = pull_from_channel(context, ch_out) + pull_nodes.append(pull_node) + outputs.append(output) + + # Run pipeline + with ThreadPoolExecutor(max_workers=1) as executor: + run_streaming_pipeline( + nodes=[push_node, fanout_node, *pull_nodes], + py_executor=executor, + ) + + # Verify all outputs received the messages + for output_idx, output in enumerate(outputs): + results = output.release() + assert len(results) == 3, ( + f"Output {output_idx}: Expected 3 messages, got {len(results)}" + ) + for i in range(3): + assert results[i].sequence_number == i + + +def test_fanout_empty_outputs(context: Context, stream: Stream) -> None: + """Test fanout with empty output list raises value error.""" + ch_in: Channel[TableChunk] = context.create_channel() + with pytest.raises(ValueError): + fanout(context, ch_in, [], FanoutPolicy.BOUNDED) + + +def test_fanout_policy_enum() -> None: + """Test that FanoutPolicy enum has correct values.""" + assert FanoutPolicy.BOUNDED == 0 + assert FanoutPolicy.UNBOUNDED == 1 + assert len(FanoutPolicy) == 2
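
For reference, the end-to-end wiring exercised by the C++ tests reduces to one source node, the fanout node, and one sink per output channel. The sketch below follows the pattern in cpp/tests/streaming/test_fanout.cpp; the header paths and the push_to_channel / pull_from_channel / run_streaming_pipeline helpers are taken from this diff and assumed to be available as shown, so treat it as a sketch rather than canonical usage. BOUNDED is the natural policy here because each output is drained by its own independent consumer; a single consumer that drains one channel completely before touching the next needs UNBOUNDED, as the ManyInputSinkStreamingFanout ChannelOrder test demonstrates.

// Sketch: broadcast a prepared batch of messages to two independent sinks.
// Header paths are assumed from the repository layout in this diff.
#include <rapidsmpf/streaming/core/fanout.hpp>
#include <rapidsmpf/streaming/core/leaf_node.hpp>
#include <rapidsmpf/streaming/core/node.hpp>

#include <memory>
#include <vector>

using namespace rapidsmpf::streaming;

void broadcast_to_two_sinks(std::shared_ptr<Context> ctx, std::vector<Message> inputs) {
    // Output storage must outlive the pipeline run; one vector per sink.
    std::vector<std::vector<Message>> outs(2);
    std::vector<Node> nodes;

    // Source: push the prepared messages into the fanout's input channel.
    auto ch_in = ctx->create_channel();
    nodes.emplace_back(node::push_to_channel(ctx, ch_in, std::move(inputs)));

    // Fanout: each output channel is drained by an independent sink, so the
    // backpressure-preserving BOUNDED policy is sufficient.
    std::vector<std::shared_ptr<Channel>> chs_out{
        ctx->create_channel(), ctx->create_channel()
    };
    nodes.emplace_back(node::fanout(ctx, ch_in, chs_out, node::FanoutPolicy::BOUNDED));

    // Sinks: drain each output channel into its own vector.
    for (size_t i = 0; i < chs_out.size(); ++i) {
        nodes.emplace_back(node::pull_from_channel(ctx, chs_out[i], outs[i]));
    }

    // Runs all nodes to completion; every sink ends up with a deep copy of
    // every input message.
    run_streaming_pipeline(std::move(nodes));
}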
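
The pull-based coordination described in the UnboundedFanout comment — one receive task caching messages in a deque, N send tasks each tracking the index of the next message to forward, a data_ready condition from the receiver to the senders and a request_data condition back — can be pictured with ordinary thread primitives. The sketch below is only an illustrative analogue under that reading (FanoutState, sender, receiver and the use of plain int "messages" are invented for illustration); the node itself uses coro::mutex and coro::condition_variable so tasks suspend instead of blocking, and it additionally purges cached messages, deep-copies payloads through the buffer resource, and drops channels that shut down early.

// Thread-based analogue of the unbounded-fanout coordination pattern.
#include <algorithm>
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

struct FanoutState {
    std::mutex mtx;
    std::condition_variable data_ready;    // receiver -> senders: new message cached
    std::condition_variable request_data;  // senders -> receiver: caught up, pull more
    std::deque<int> cached;                // messages received so far
    std::vector<size_t> next_idx;          // per-sender index of the next message to send
    bool no_more_input = false;
};

void sender(FanoutState& st, size_t id) {
    std::unique_lock lock{st.mtx};
    for (;;) {
        st.data_ready.wait(lock, [&] {
            return st.no_more_input || st.next_idx[id] < st.cached.size();
        });
        while (st.next_idx[id] < st.cached.size()) {
            int msg = st.cached[st.next_idx[id]++];
            lock.unlock();  // "copy and send" happens outside the lock
            std::printf("sender %zu forwarded %d\n", id, msg);
            lock.lock();
        }
        if (st.no_more_input) return;  // drained everything, nothing more will arrive
        st.request_data.notify_one();  // ask the receiver to pull the next message
    }
}

void receiver(FanoutState& st, int n_messages) {
    for (int i = 0; i < n_messages; ++i) {
        std::unique_lock lock{st.mtx};
        // Wait until the fastest sender has consumed everything cached so far.
        st.request_data.wait(lock, [&] {
            return std::ranges::max(st.next_idx) == st.cached.size();
        });
        st.cached.push_back(i);  // "receive" the next input message
        lock.unlock();
        st.data_ready.notify_all();
    }
    {
        std::lock_guard lock{st.mtx};
        st.no_more_input = true;  // input exhausted; let senders wind down
    }
    st.data_ready.notify_all();
}

int main() {
    constexpr size_t n_senders = 3;
    FanoutState st;
    st.next_idx.resize(n_senders, 0);

    std::vector<std::jthread> threads;
    for (size_t i = 0; i < n_senders; ++i) {
        threads.emplace_back(sender, std::ref(st), i);
    }
    threads.emplace_back(receiver, std::ref(st), 10);
    return 0;  // jthreads join on destruction
}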