From feb5b593ef931f175f21dc843213bc3711471dab Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 3 Nov 2025 15:57:27 -0800 Subject: [PATCH 01/43] porting code from mads' PR Signed-off-by: niranda perera --- cpp/CMakeLists.txt | 1 + .../rapidsmpf/streaming/core/fanout.hpp | 75 +++++++++ .../rapidsmpf/streaming/core/message.hpp | 4 +- cpp/src/streaming/core/fanout.cpp | 89 ++++++++++ cpp/tests/CMakeLists.txt | 1 + cpp/tests/streaming/test_fanout.cpp | 152 ++++++++++++++++++ 6 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 cpp/include/rapidsmpf/streaming/core/fanout.hpp create mode 100644 cpp/src/streaming/core/fanout.cpp create mode 100644 cpp/tests/streaming/test_fanout.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9c5ea3d66..9bd0047b8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -182,6 +182,7 @@ if(RAPIDSMPF_HAVE_STREAMING) src/streaming/coll/allgather.cpp src/streaming/coll/shuffler.cpp src/streaming/core/context.cpp + src/streaming/core/fanout.cpp src/streaming/core/leaf_node.cpp src/streaming/core/node.cpp src/streaming/cudf/partition.cpp diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp new file mode 100644 index 000000000..0331e0bf0 --- /dev/null +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -0,0 +1,75 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: Apache-2.0 + */ + + #pragma once + + #include + #include + #include + + namespace rapidsmpf::streaming::node { + + /** + * @brief Fanout policy controlling how messages are propagated. + */ + enum class FanoutPolicy : int { + /** + * @brief Process messages as they arrive and immediately forward them. + * + * Messages are forwarded as soon as they are received from the input channel. + * The next message is not processed until all output channels have completed + * sending the current one, ensuring backpressure and synchronized flow. + */ + BOUNDED, + + /** + * @brief Forward messages without enforcing backpressure. + * + * In this mode, messages may be accumulated internally before being + * broadcast, or they may be forwarded immediately depending on the + * implementation and downstream consumption rate. + * + * This mode disables coordinated backpressure between outputs, allowing + * consumers to process at independent rates, but can lead to unbounded + * buffering and increased memory usage. + * + * @note Consumers might not receive any messages until *all* upstream + * messages have been sent, depending on the implementation and buffering + * strategy. + */ + UNBOUNDED, + }; + + /** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * The node continuously receives messages from the input channel and forwards + * them to all output channels according to the selected fanout policy, see + * ::FanoutPolicy. + * + * Each output channel receives a shallow copy of the same message; no payload + * data is duplicated. All copies share the same underlying payload, ensuring + * zero-copy broadcast semantics. + * + * @param ctx The node context to use. + * @param ch_in Input channel from which messages are received. + * @param chs_out Output channels to which messages are broadcast. + * @param policy The fanout strategy to use (see ::FanoutPolicy). + * + * @return Streaming node representing the fanout operation. + * + * @throws std::invalid_argument If an unknown fanout policy is specified. 
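+ *
+ * Illustrative usage (a minimal sketch based on the accompanying tests; `ctx`,
+ * `in`, `out1`, and `out2` are assumed to be an existing streaming context and
+ * channels created elsewhere):
+ * @code{.cpp}
+ * // Sketch only: broadcast every message arriving on `in` to both outputs.
+ * Node broadcast = fanout(ctx, in, {out1, out2}, FanoutPolicy::BOUNDED);
+ * @endcode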
+ * + * @note Since messages are shallow-copied, releasing a payload (`release()`) + * is only valid on messages that hold exclusive ownership of the payload. + */ + Node fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out, + FanoutPolicy policy + ); + + } // namespace rapidsmpf::streaming::node \ No newline at end of file diff --git a/cpp/include/rapidsmpf/streaming/core/message.hpp b/cpp/include/rapidsmpf/streaming/core/message.hpp index e2631f40c..05b819b40 100644 --- a/cpp/include/rapidsmpf/streaming/core/message.hpp +++ b/cpp/include/rapidsmpf/streaming/core/message.hpp @@ -218,7 +218,7 @@ class Message { * * @throws std::invalid_argument if the message does not support `content_size`. */ - [[nodiscard]] std::pair content_size(MemoryType mem_type) { + [[nodiscard]] std::pair content_size(MemoryType mem_type) const { RAPIDSMPF_EXPECTS( callbacks_.content_size, "message doesn't support `content_size`", @@ -235,7 +235,7 @@ class Message { * * @see copy() */ - [[nodiscard]] size_t copy_cost() { + [[nodiscard]] size_t copy_cost() const { size_t ret = 0; for (MemoryType mem_type : MEMORY_TYPES) { ret += content_size(mem_type).first; diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp new file mode 100644 index 000000000..0b9b6733f --- /dev/null +++ b/cpp/src/streaming/core/fanout.cpp @@ -0,0 +1,89 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +namespace rapidsmpf::streaming::node { +namespace { +/** + * @brief Asynchronously send a message to multiple output channels. + * + * @param msg The message to broadcast. Each channel receives a shallow + * copy of the original message. + * @param chs_out The set of output channels to which the message is sent. + */ +Node send_to_channels( + Context* ctx, Message const& msg, std::vector>& chs_out +) { + std::vector> tasks; + tasks.reserve(chs_out.size()); + for (auto& ch_out : chs_out) { + // do a reservation for each copy, so that it will fallback to host memory if + // needed + auto res = ctx->br()->reserve_or_fail(msg.copy_cost()); + tasks.push_back(ch_out->send(msg.copy(ctx->br(), res))); + } + coro_results(co_await coro::when_all(std::move(tasks))); +} +} // namespace + +Node fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out, + FanoutPolicy policy +) { + ShutdownAtExit c1{ch_in}; + ShutdownAtExit c2{chs_out}; + co_await ctx->executor()->schedule(); + + switch (policy) { + case FanoutPolicy::BOUNDED: + while (true) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + co_await send_to_channels(ctx.get(), msg, chs_out); + } + break; + case FanoutPolicy::UNBOUNDED: + // TODO: Instead of buffering all messages before broadcasting, + // stream them directly by giving each output channel its own + // `coro::queue` and spawning a coroutine per channel that + // sends from that queue. + { + // First we receive until the input channel is shutdown. + std::vector messages; + while (true) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + messages.push_back(std::move(msg)); + } + // Then we send each input message to all output channels. + for (auto& msg : messages) { + co_await send_to_channels(ctx.get(), msg, chs_out); + } + break; + } + default: + RAPIDSMPF_FAIL("Unknown broadcast policy", std::invalid_argument); + } + + // Finally, we drain all output channels. 
+ std::vector tasks; + tasks.reserve(chs_out.size()); + for (auto& ch_out : chs_out) { + tasks.push_back(ch_out->drain(ctx->executor())); + } + coro_results(co_await coro::when_all(std::move(tasks))); +} + +} // namespace rapidsmpf::streaming::node \ No newline at end of file diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 522ec9cd2..724ea33f7 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -72,6 +72,7 @@ if(RAPIDSMPF_HAVE_STREAMING) test_sources INTERFACE streaming/test_allgather.cpp streaming/test_error_handling.cpp + streaming/test_fanout.cpp streaming/test_leaf_node.cpp streaming/test_lineariser.cpp streaming/test_message.cpp diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp new file mode 100644 index 000000000..358ba4167 --- /dev/null +++ b/cpp/tests/streaming/test_fanout.cpp @@ -0,0 +1,152 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: Apache-2.0 + */ + + #include + #include + + #include + + #include + #include + #include + + #include "base_streaming_fixture.hpp" + + using namespace rapidsmpf; + using namespace rapidsmpf::streaming; + namespace node = rapidsmpf::streaming::node; + using rapidsmpf::streaming::node::FanoutPolicy; + + using StreamingFanout = BaseStreamingFixture; + + namespace { + + /** + * @brief Helper to make a sequence of Message with values [0, n). + */ + std::vector make_int_inputs(int n) { + std::vector inputs; + inputs.reserve(n); + for (int i = 0; i < n; ++i) { + inputs.emplace_back(i, std::make_unique(i)); + } + return inputs; + } + + } // namespace + + TEST_F(StreamingFanout, Bounded) { + int const num_msgs = 10; + + // Prepare inputs + auto inputs = make_int_inputs(num_msgs); + + // Create pipeline + std::vector outs1, outs2, outs3; + { + std::vector nodes; + + auto in = std::make_shared(); + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + + auto out1 = std::make_shared(); + auto out2 = std::make_shared(); + auto out3 = std::make_shared(); + nodes.push_back(node::fanout(ctx, in, {out1, out2, out3}, FanoutPolicy::BOUNDED)); + nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); + nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); + nodes.push_back(node::pull_from_channel(ctx, out3, outs3)); + + run_streaming_pipeline(std::move(nodes)); + } + + // Validate sizes + EXPECT_EQ(outs1.size(), static_cast(num_msgs)); + EXPECT_EQ(outs2.size(), static_cast(num_msgs)); + EXPECT_EQ(outs3.size(), static_cast(num_msgs)); + + // Validate ordering/content and that shallow copies share the same underlying object + for (int i = 0; i < num_msgs; ++i) { + // Same address means same shared payload (shallow copy) + EXPECT_EQ( + std::addressof(outs1[i].get()), std::addressof(outs2[i].get()) + ); + EXPECT_EQ( + std::addressof(outs1[i].get()), std::addressof(outs3[i].get()) + ); + EXPECT_EQ(outs1[i].get(), i); + EXPECT_EQ(outs2[i].get(), i); + EXPECT_EQ(outs3[i].get(), i); + } + + // release() semantics: requires sole ownership. + // For each triplet, drop two references, then release from the remaining one. 
+ for (int i = 0; i < num_msgs; ++i) { + // Holding 3 references -> release must fail + EXPECT_THROW(std::ignore = outs1[i].release(), std::invalid_argument); + + // Make outs1[i] the sole owner + outs2[i].reset(); + outs3[i].reset(); + + // Now release succeeds and yields the expected value + EXPECT_NO_THROW({ + int v = outs1[i].release(); + EXPECT_EQ(v, i); + }); + + // After release, the message is empty + EXPECT_TRUE(outs1[i].empty()); + } + } + + TEST_F(StreamingFanout, Unbounded) { + int const num_msgs = 7; + + auto inputs = make_int_inputs(num_msgs); + + std::vector outs1, outs2; + { + std::vector nodes; + + auto in = std::make_shared(); + auto out1 = std::make_shared(); + auto out2 = std::make_shared(); + + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + + // UNBOUNDED policy: buffer all inputs, then broadcast after input closes. + nodes.push_back(node::fanout(ctx, in, {out1, out2}, FanoutPolicy::UNBOUNDED)); + + nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); + nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); + + run_streaming_pipeline(std::move(nodes)); + } + + ASSERT_EQ(outs1.size(), static_cast(num_msgs)); + ASSERT_EQ(outs2.size(), static_cast(num_msgs)); + + // Order and identity must be preserved + for (int i = 0; i < num_msgs; ++i) { + EXPECT_EQ(&outs1[i].get(), &outs2[i].get()); + EXPECT_EQ(outs1[i].get(), i); + EXPECT_EQ(outs2[i].get(), i); + } + + // Release semantics: with two refs, release must fail until one is reset. + for (int i = 0; i < num_msgs; ++i) { + EXPECT_THROW(std::ignore = outs1[i].release(), std::invalid_argument); + + // Make outs1[i] sole owner + outs2[i].reset(); + + EXPECT_NO_THROW({ + int v = outs1[i].release(); + EXPECT_EQ(v, i); + }); + EXPECT_TRUE(outs1[i].empty()); + } + } \ No newline at end of file From 6ea737c157555fcdd3912d65021514abe159cd67 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 4 Nov 2025 17:03:41 -0800 Subject: [PATCH 02/43] WIP Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 123 +++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 4 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 0b9b6733f..a3e9062f9 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include + #include #include @@ -30,6 +32,123 @@ Node send_to_channels( } coro_results(co_await coro::when_all(std::move(tasks))); } + +struct UnboundedFanoutState { + UnboundedFanoutState(std::vector>&& chs_out_) + : chs_out{std::move(chs_out_)}, + ch_next_idx{chs_out.size(), 0}, + ch_data_avail{chs_out.size(), {}}, + send_tasks{chs_out.size()} {} + + coro::task receive_done() { + auto lock = co_await mtx.scoped_lock(); + n_msgs = chs_out.size(); + } + + [[nodiscard]] constexpr bool all_received() const { + return n_msgs != std::numeric_limits::max(); + } + + [[nodiscard]] constexpr bool all_sent(size_t i) const { + return all_received() && ch_next_idx[i] == n_msgs; + } + + // [[nodiscard]] constexpr size_t last_completed_idx() const { + // return std::ranges::min(ch_next_idx); + // } + + // thread-safe data for each send task + std::vector> chs_out; + std::vector ch_next_idx; // values are strictly increasing + std::vector ch_data_avail; + + std::vector send_tasks; + + std::vector recv_messages; + + coro::mutex mtx; + coro::condition_variable cv; + size_t n_msgs{std::numeric_limits::max()}; +}; + +Node send_task(Context* ctx, UnboundedFanoutState& 
state, size_t i) { + co_await ctx->executor()->schedule(); + + // co_await state.data_avail; // wait for data to be available + + while (true) { + // wait for the data to be available + co_await state.ch_data_avail[i]; + + if (state.all_sent(i)) { + // all messages have been sent, nothing else to do + break; + } + + auto const& msg = state.recv_messages[state.ch_next_idx[i]]; + // copy msg + // msg.content_size + + { + auto lock = state.mtx.scoped_lock(); + state.ch_next_idx[i]++; + } + co_await state.cv.notify_one(); + + + // + if (state.ch_next_idx[i] == state.recv_messages.size() && !state.all_received()) { + state.ch_data_avail[i].reset(); + } + } +} + +Node unbounded_fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out +) { + ShutdownAtExit c{ch_in}; + ShutdownAtExit c2{chs_out}; + co_await ctx->executor()->schedule(); + + UnboundedFanoutState state(std::move(chs_out)); + + size_t purge_idx = 0; + while (true) { + { + auto lock = co_await state.mtx.scoped_lock(); + co_await state.cv.wait(lock, [&] { + return state.recv_messages.size() <= std::ranges::max(state.ch_next_idx); + }); + } + + // n_msgs is only set by this task. So, reading w/o a lock is safe. + if (state.n_msgs == std::numeric_limits::max()) { + auto msg = co_await ch_in->receive(); + auto lock = co_await state.mtx.scoped_lock(); + if (msg.empty()) { + // no more messages to receive + state.n_msgs = state.recv_messages.size(); + } else { + state.recv_messages.push_back(std::move(msg)); + lock.unlock(); + + for (auto& event : state.ch_data_avail) { + event.set(); + } + } + } + + size_t last_completed_idx = std::ranges::min(state.ch_next_idx); + while (purge_idx <= last_completed_idx) { + state.ch_data_avail[purge_idx].reset(); + purge_idx++; + } + } +} + + } // namespace Node fanout( @@ -53,10 +172,6 @@ Node fanout( } break; case FanoutPolicy::UNBOUNDED: - // TODO: Instead of buffering all messages before broadcasting, - // stream them directly by giving each output channel its own - // `coro::queue` and spawning a coroutine per channel that - // sends from that queue. { // First we receive until the input channel is shutdown. std::vector messages; From a7254bab34c88a39d537c4ff904c6d24f85a631a Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 5 Nov 2025 11:55:13 -0800 Subject: [PATCH 03/43] adding fanout Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/fanout.hpp | 2 +- cpp/src/streaming/core/fanout.cpp | 201 +++++++++--------- 2 files changed, 102 insertions(+), 101 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp index 0331e0bf0..e9fe93842 100644 --- a/cpp/include/rapidsmpf/streaming/core/fanout.hpp +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -14,7 +14,7 @@ /** * @brief Fanout policy controlling how messages are propagated. */ - enum class FanoutPolicy : int { + enum class FanoutPolicy : uint8_t { /** * @brief Process messages as they arrive and immediately forward them. 
* diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index a3e9062f9..377d0c2af 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -4,6 +4,7 @@ */ #include +#include #include #include @@ -33,119 +34,130 @@ Node send_to_channels( coro_results(co_await coro::when_all(std::move(tasks))); } -struct UnboundedFanoutState { - UnboundedFanoutState(std::vector>&& chs_out_) - : chs_out{std::move(chs_out_)}, - ch_next_idx{chs_out.size(), 0}, - ch_data_avail{chs_out.size(), {}}, - send_tasks{chs_out.size()} {} - - coro::task receive_done() { - auto lock = co_await mtx.scoped_lock(); - n_msgs = chs_out.size(); - } - - [[nodiscard]] constexpr bool all_received() const { - return n_msgs != std::numeric_limits::max(); - } - - [[nodiscard]] constexpr bool all_sent(size_t i) const { - return all_received() && ch_next_idx[i] == n_msgs; - } - - // [[nodiscard]] constexpr size_t last_completed_idx() const { - // return std::ranges::min(ch_next_idx); - // } - - // thread-safe data for each send task - std::vector> chs_out; - std::vector ch_next_idx; // values are strictly increasing - std::vector ch_data_avail; - - std::vector send_tasks; - - std::vector recv_messages; - - coro::mutex mtx; - coro::condition_variable cv; - size_t n_msgs{std::numeric_limits::max()}; -}; - -Node send_task(Context* ctx, UnboundedFanoutState& state, size_t i) { - co_await ctx->executor()->schedule(); - - // co_await state.data_avail; // wait for data to be available +Node unbounded_fo_send_task( + Context& ctx, + std::shared_ptr const& ch_out, + size_t* next_idx, + coro::mutex& mtx, + coro::condition_variable& data_ready, + coro::condition_variable& request_data, + bool const& input_done, + std::vector const& recv_messages +) { + ShutdownAtExit c{ch_out}; + co_await ctx.executor()->schedule(); + size_t end_idx; while (true) { - // wait for the data to be available - co_await state.ch_data_avail[i]; + { + auto lock = co_await mtx.scoped_lock(); + co_await data_ready.wait(lock, [&] { + // irrespective of input_done, update the end_idx to the total number of + // messages + end_idx = recv_messages.size(); + return input_done || *next_idx < end_idx; + }); - if (state.all_sent(i)) { - // all messages have been sent, nothing else to do - break; + if (input_done && *next_idx == end_idx) { + break; + } } - auto const& msg = state.recv_messages[state.ch_next_idx[i]]; - // copy msg - // msg.content_size - - { - auto lock = state.mtx.scoped_lock(); - state.ch_next_idx[i]++; + // now we can copy & send messages in indices [next_idx, end_idx) + for (size_t i = *next_idx; i < end_idx; i++) { + auto const& msg = recv_messages[i]; + auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); + co_await ch_out->send(msg.copy(ctx.br(), res)); } - co_await state.cv.notify_one(); - - // - if (state.ch_next_idx[i] == state.recv_messages.size() && !state.all_received()) { - state.ch_data_avail[i].reset(); + // now next_idx can be updated to end_idx, and if !input_done, we need to request + // parent task for more data + auto lock = co_await mtx.scoped_lock(); + *next_idx = end_idx; + if (input_done) { + break; + } else { + lock.unlock(); + co_await request_data.notify_one(); } } + + // channels will be drained by the caller } Node unbounded_fanout( - std::shared_ptr ctx, - std::shared_ptr ch_in, - std::vector> chs_out + Context& ctx, + std::shared_ptr const& ch_in, + std::vector> const& chs_out ) { ShutdownAtExit c{ch_in}; ShutdownAtExit c2{chs_out}; - co_await ctx->executor()->schedule(); 
+ co_await ctx.executor()->schedule(); + - UnboundedFanoutState state(std::move(chs_out)); + coro::mutex mtx; + coro::condition_variable data_ready; + coro::condition_variable request_data; + bool input_done{false}; + std::vector recv_messages; + + std::vector ch_next_idx{chs_out.size(), 0}; + std::vector tasks; + tasks.reserve(chs_out.size()); + for (size_t i = 0; i < chs_out.size(); i++) { + tasks.emplace_back(unbounded_fo_send_task( + ctx, + chs_out[i], + &ch_next_idx[i], + mtx, + data_ready, + request_data, + input_done, + recv_messages + )); + } size_t purge_idx = 0; - while (true) { + // input_done is only set by this task, so reading without lock is safe here + while (!input_done) { { - auto lock = co_await state.mtx.scoped_lock(); - co_await state.cv.wait(lock, [&] { - return state.recv_messages.size() <= std::ranges::max(state.ch_next_idx); + auto lock = co_await mtx.scoped_lock(); + co_await request_data.wait(lock, [&] { + return std::ranges::any_of(ch_next_idx, [&](size_t next_idx) { + return recv_messages.size() >= next_idx; + }); }); } - // n_msgs is only set by this task. So, reading w/o a lock is safe. - if (state.n_msgs == std::numeric_limits::max()) { - auto msg = co_await ch_in->receive(); - auto lock = co_await state.mtx.scoped_lock(); + // receive a message from the input channel + auto msg = co_await ch_in->receive(); + + { // relock mtx to update input_done/ recv_messages + auto lock = co_await mtx.scoped_lock(); if (msg.empty()) { - // no more messages to receive - state.n_msgs = state.recv_messages.size(); + input_done = true; } else { - state.recv_messages.push_back(std::move(msg)); - lock.unlock(); - - for (auto& event : state.ch_data_avail) { - event.set(); - } + recv_messages.emplace_back(std::move(msg)); } } - - size_t last_completed_idx = std::ranges::min(state.ch_next_idx); - while (purge_idx <= last_completed_idx) { - state.ch_data_avail[purge_idx].reset(); + // notify send_tasks to copy & send messages + co_await data_ready.notify_all(); + + // purge completed send_tasks + // intentionally not locking the mtx here, because we only need to know a + // lower-bound on the last completed idx (ch_next_idx values are monotonically + // increasing) + size_t last_completed_idx = std::ranges::min(ch_next_idx) - 1; + while (purge_idx < last_completed_idx) { + recv_messages[purge_idx].reset(); purge_idx++; } } + + // Note: there will be some messages to be purged after the loop exits, but we don't + // need to do anything about them here + + coro_results(co_await coro::when_all(std::move(tasks))); } @@ -173,19 +185,7 @@ Node fanout( break; case FanoutPolicy::UNBOUNDED: { - // First we receive until the input channel is shutdown. - std::vector messages; - while (true) { - auto msg = co_await ch_in->receive(); - if (msg.empty()) { - break; - } - messages.push_back(std::move(msg)); - } - // Then we send each input message to all output channels. - for (auto& msg : messages) { - co_await send_to_channels(ctx.get(), msg, chs_out); - } + co_await unbounded_fanout(*ctx, ch_in, chs_out); break; } default: @@ -195,9 +195,10 @@ Node fanout( // Finally, we drain all output channels. 
std::vector tasks; tasks.reserve(chs_out.size()); - for (auto& ch_out : chs_out) { - tasks.push_back(ch_out->drain(ctx->executor())); - } + std::ranges::transform(chs_out, std::back_inserter(tasks), [&](auto& ch_out) { + return ch_out->drain(ctx->executor()); + }); + coro_results(co_await coro::when_all(std::move(tasks))); } From 91bd6c77322019b44917c054e68d4932183093bc Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 5 Nov 2025 13:30:40 -0800 Subject: [PATCH 04/43] working draft Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/fanout.hpp | 138 ++++----- cpp/src/streaming/core/fanout.cpp | 109 +++---- cpp/tests/streaming/test_fanout.cpp | 266 ++++++++---------- 3 files changed, 247 insertions(+), 266 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp index e9fe93842..01e027556 100644 --- a/cpp/include/rapidsmpf/streaming/core/fanout.hpp +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -3,73 +3,73 @@ * SPDX-License-Identifier: Apache-2.0 */ - #pragma once +#pragma once - #include - #include - #include - - namespace rapidsmpf::streaming::node { - - /** - * @brief Fanout policy controlling how messages are propagated. - */ - enum class FanoutPolicy : uint8_t { - /** - * @brief Process messages as they arrive and immediately forward them. - * - * Messages are forwarded as soon as they are received from the input channel. - * The next message is not processed until all output channels have completed - * sending the current one, ensuring backpressure and synchronized flow. - */ - BOUNDED, - - /** - * @brief Forward messages without enforcing backpressure. - * - * In this mode, messages may be accumulated internally before being - * broadcast, or they may be forwarded immediately depending on the - * implementation and downstream consumption rate. - * - * This mode disables coordinated backpressure between outputs, allowing - * consumers to process at independent rates, but can lead to unbounded - * buffering and increased memory usage. - * - * @note Consumers might not receive any messages until *all* upstream - * messages have been sent, depending on the implementation and buffering - * strategy. - */ - UNBOUNDED, - }; - - /** - * @brief Broadcast messages from one input channel to multiple output channels. - * - * The node continuously receives messages from the input channel and forwards - * them to all output channels according to the selected fanout policy, see - * ::FanoutPolicy. - * - * Each output channel receives a shallow copy of the same message; no payload - * data is duplicated. All copies share the same underlying payload, ensuring - * zero-copy broadcast semantics. - * - * @param ctx The node context to use. - * @param ch_in Input channel from which messages are received. - * @param chs_out Output channels to which messages are broadcast. - * @param policy The fanout strategy to use (see ::FanoutPolicy). - * - * @return Streaming node representing the fanout operation. - * - * @throws std::invalid_argument If an unknown fanout policy is specified. - * - * @note Since messages are shallow-copied, releasing a payload (`release()`) - * is only valid on messages that hold exclusive ownership of the payload. 
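+ *
+ * Illustrative usage (a minimal sketch mirroring the unit tests; `ctx` is an
+ * existing streaming Context and the channels come from `Context::create_channel()`):
+ * @code{.cpp}
+ * auto in = ctx->create_channel();
+ * auto out1 = ctx->create_channel();
+ * auto out2 = ctx->create_channel();
+ * // Sketch only: with UNBOUNDED, outputs may not see data until the input closes.
+ * Node broadcast = fanout(ctx, in, {out1, out2}, FanoutPolicy::UNBOUNDED);
+ * @endcode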
- */ - Node fanout( - std::shared_ptr ctx, - std::shared_ptr ch_in, - std::vector> chs_out, - FanoutPolicy policy - ); - - } // namespace rapidsmpf::streaming::node \ No newline at end of file +#include +#include +#include + +namespace rapidsmpf::streaming::node { + +/** + * @brief Fanout policy controlling how messages are propagated. + */ +enum class FanoutPolicy : uint8_t { + /** + * @brief Process messages as they arrive and immediately forward them. + * + * Messages are forwarded as soon as they are received from the input channel. + * The next message is not processed until all output channels have completed + * sending the current one, ensuring backpressure and synchronized flow. + */ + BOUNDED, + + /** + * @brief Forward messages without enforcing backpressure. + * + * In this mode, messages may be accumulated internally before being + * broadcast, or they may be forwarded immediately depending on the + * implementation and downstream consumption rate. + * + * This mode disables coordinated backpressure between outputs, allowing + * consumers to process at independent rates, but can lead to unbounded + * buffering and increased memory usage. + * + * @note Consumers might not receive any messages until *all* upstream + * messages have been sent, depending on the implementation and buffering + * strategy. + */ + UNBOUNDED, +}; + +/** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * The node continuously receives messages from the input channel and forwards + * them to all output channels according to the selected fanout policy, see + * ::FanoutPolicy. + * + * Each output channel receives a shallow copy of the same message; no payload + * data is duplicated. All copies share the same underlying payload, ensuring + * zero-copy broadcast semantics. + * + * @param ctx The node context to use. + * @param ch_in Input channel from which messages are received. + * @param chs_out Output channels to which messages are broadcast. + * @param policy The fanout strategy to use (see ::FanoutPolicy). + * + * @return Streaming node representing the fanout operation. + * + * @throws std::invalid_argument If an unknown fanout policy is specified. + * + * @note Since messages are shallow-copied, releasing a payload (`release()`) + * is only valid on messages that hold exclusive ownership of the payload. + */ +Node fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out, + FanoutPolicy policy +); + +} // namespace rapidsmpf::streaming::node diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 377d0c2af..7d274d359 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -29,24 +29,51 @@ Node send_to_channels( // do a reservation for each copy, so that it will fallback to host memory if // needed auto res = ctx->br()->reserve_or_fail(msg.copy_cost()); - tasks.push_back(ch_out->send(msg.copy(ctx->br(), res))); + tasks.push_back(ch_out->send(msg.copy(res))); } coro_results(co_await coro::when_all(std::move(tasks))); } +Node bounded_fanout( + std::shared_ptr&& ctx, + std::shared_ptr&& ch_in, + std::vector>&& chs_out +) { + ShutdownAtExit c1{ch_in}; + ShutdownAtExit c2{chs_out}; + co_await ctx->executor()->schedule(); + + while (true) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + co_await send_to_channels(ctx.get(), msg, chs_out); + } + + // Finally, we drain all output channels. 
+ std::vector tasks; + tasks.reserve(chs_out.size()); + std::ranges::transform(chs_out, std::back_inserter(tasks), [&](auto& ch_out) { + return ch_out->drain(ctx->executor()); + }); + + coro_results(co_await coro::when_all(std::move(tasks))); +} + Node unbounded_fo_send_task( Context& ctx, - std::shared_ptr const& ch_out, - size_t* next_idx, + size_t idx, + std::vector> const& chs_out, + std::vector& ch_next_idx, coro::mutex& mtx, coro::condition_variable& data_ready, coro::condition_variable& request_data, bool const& input_done, std::vector const& recv_messages ) { - ShutdownAtExit c{ch_out}; + ShutdownAtExit c{chs_out[idx]}; co_await ctx.executor()->schedule(); - size_t end_idx; while (true) { { @@ -55,26 +82,28 @@ Node unbounded_fo_send_task( // irrespective of input_done, update the end_idx to the total number of // messages end_idx = recv_messages.size(); - return input_done || *next_idx < end_idx; + return input_done || ch_next_idx[idx] < end_idx; }); - - if (input_done && *next_idx == end_idx) { + if (input_done && ch_next_idx[idx] == end_idx) { break; } } // now we can copy & send messages in indices [next_idx, end_idx) - for (size_t i = *next_idx; i < end_idx; i++) { + for (size_t i = ch_next_idx[idx]; i < end_idx; i++) { auto const& msg = recv_messages[i]; + + // make reservations for each message so that it will fallback to host memory + // if needed auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); - co_await ch_out->send(msg.copy(ctx.br(), res)); + co_await chs_out[idx]->send(msg.copy(res)); } // now next_idx can be updated to end_idx, and if !input_done, we need to request // parent task for more data auto lock = co_await mtx.scoped_lock(); - *next_idx = end_idx; - if (input_done) { + ch_next_idx[idx] = end_idx; + if (input_done && ch_next_idx[idx] == end_idx) { break; } else { lock.unlock(); @@ -82,17 +111,17 @@ Node unbounded_fo_send_task( } } - // channels will be drained by the caller + co_await chs_out[idx]->drain(ctx.executor()); } Node unbounded_fanout( - Context& ctx, - std::shared_ptr const& ch_in, - std::vector> const& chs_out + std::shared_ptr&& ctx, + std::shared_ptr&& ch_in, + std::vector>&& chs_out ) { ShutdownAtExit c{ch_in}; ShutdownAtExit c2{chs_out}; - co_await ctx.executor()->schedule(); + co_await ctx->executor()->schedule(); coro::mutex mtx; @@ -101,14 +130,16 @@ Node unbounded_fanout( bool input_done{false}; std::vector recv_messages; - std::vector ch_next_idx{chs_out.size(), 0}; + std::vector ch_next_idx(chs_out.size(), 0); + std::vector tasks; tasks.reserve(chs_out.size()); for (size_t i = 0; i < chs_out.size(); i++) { tasks.emplace_back(unbounded_fo_send_task( - ctx, - chs_out[i], - &ch_next_idx[i], + *ctx, + i, + chs_out, + ch_next_idx, mtx, data_ready, request_data, @@ -129,7 +160,7 @@ Node unbounded_fanout( }); } - // receive a message from the input channel + // receive a message from the input channel auto msg = co_await ch_in->receive(); { // relock mtx to update input_done/ recv_messages @@ -147,8 +178,8 @@ Node unbounded_fanout( // intentionally not locking the mtx here, because we only need to know a // lower-bound on the last completed idx (ch_next_idx values are monotonically // increasing) - size_t last_completed_idx = std::ranges::min(ch_next_idx) - 1; - while (purge_idx < last_completed_idx) { + size_t last_completed_idx = std::ranges::min(ch_next_idx); + while (purge_idx + 1 < last_completed_idx) { recv_messages[purge_idx].reset(); purge_idx++; } @@ -156,7 +187,6 @@ Node unbounded_fanout( // Note: there will be some messages to be 
purged after the loop exits, but we don't // need to do anything about them here - coro_results(co_await coro::when_all(std::move(tasks))); } @@ -169,37 +199,16 @@ Node fanout( std::vector> chs_out, FanoutPolicy policy ) { - ShutdownAtExit c1{ch_in}; - ShutdownAtExit c2{chs_out}; - co_await ctx->executor()->schedule(); - switch (policy) { case FanoutPolicy::BOUNDED: - while (true) { - auto msg = co_await ch_in->receive(); - if (msg.empty()) { - break; - } - co_await send_to_channels(ctx.get(), msg, chs_out); - } + co_await bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); break; case FanoutPolicy::UNBOUNDED: - { - co_await unbounded_fanout(*ctx, ch_in, chs_out); - break; - } + co_await unbounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); + break; default: RAPIDSMPF_FAIL("Unknown broadcast policy", std::invalid_argument); } - - // Finally, we drain all output channels. - std::vector tasks; - tasks.reserve(chs_out.size()); - std::ranges::transform(chs_out, std::back_inserter(tasks), [&](auto& ch_out) { - return ch_out->drain(ctx->executor()); - }); - - coro_results(co_await coro::when_all(std::move(tasks))); } -} // namespace rapidsmpf::streaming::node \ No newline at end of file +} // namespace rapidsmpf::streaming::node diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 358ba4167..fd3e863c9 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -3,150 +3,122 @@ * SPDX-License-Identifier: Apache-2.0 */ - #include - #include - - #include - - #include - #include - #include - - #include "base_streaming_fixture.hpp" - - using namespace rapidsmpf; - using namespace rapidsmpf::streaming; - namespace node = rapidsmpf::streaming::node; - using rapidsmpf::streaming::node::FanoutPolicy; - - using StreamingFanout = BaseStreamingFixture; - - namespace { - - /** - * @brief Helper to make a sequence of Message with values [0, n). 
- */ - std::vector make_int_inputs(int n) { - std::vector inputs; - inputs.reserve(n); - for (int i = 0; i < n; ++i) { - inputs.emplace_back(i, std::make_unique(i)); - } - return inputs; - } - - } // namespace - - TEST_F(StreamingFanout, Bounded) { - int const num_msgs = 10; - - // Prepare inputs - auto inputs = make_int_inputs(num_msgs); - - // Create pipeline - std::vector outs1, outs2, outs3; - { - std::vector nodes; - - auto in = std::make_shared(); - nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); - - auto out1 = std::make_shared(); - auto out2 = std::make_shared(); - auto out3 = std::make_shared(); - nodes.push_back(node::fanout(ctx, in, {out1, out2, out3}, FanoutPolicy::BOUNDED)); - nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); - nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); - nodes.push_back(node::pull_from_channel(ctx, out3, outs3)); - - run_streaming_pipeline(std::move(nodes)); - } - - // Validate sizes - EXPECT_EQ(outs1.size(), static_cast(num_msgs)); - EXPECT_EQ(outs2.size(), static_cast(num_msgs)); - EXPECT_EQ(outs3.size(), static_cast(num_msgs)); - - // Validate ordering/content and that shallow copies share the same underlying object - for (int i = 0; i < num_msgs; ++i) { - // Same address means same shared payload (shallow copy) - EXPECT_EQ( - std::addressof(outs1[i].get()), std::addressof(outs2[i].get()) - ); - EXPECT_EQ( - std::addressof(outs1[i].get()), std::addressof(outs3[i].get()) - ); - EXPECT_EQ(outs1[i].get(), i); - EXPECT_EQ(outs2[i].get(), i); - EXPECT_EQ(outs3[i].get(), i); - } - - // release() semantics: requires sole ownership. - // For each triplet, drop two references, then release from the remaining one. - for (int i = 0; i < num_msgs; ++i) { - // Holding 3 references -> release must fail - EXPECT_THROW(std::ignore = outs1[i].release(), std::invalid_argument); - - // Make outs1[i] the sole owner - outs2[i].reset(); - outs3[i].reset(); - - // Now release succeeds and yields the expected value - EXPECT_NO_THROW({ - int v = outs1[i].release(); - EXPECT_EQ(v, i); - }); - - // After release, the message is empty - EXPECT_TRUE(outs1[i].empty()); - } - } - - TEST_F(StreamingFanout, Unbounded) { - int const num_msgs = 7; - - auto inputs = make_int_inputs(num_msgs); - - std::vector outs1, outs2; - { - std::vector nodes; - - auto in = std::make_shared(); - auto out1 = std::make_shared(); - auto out2 = std::make_shared(); - - nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); - - // UNBOUNDED policy: buffer all inputs, then broadcast after input closes. - nodes.push_back(node::fanout(ctx, in, {out1, out2}, FanoutPolicy::UNBOUNDED)); - - nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); - nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); - - run_streaming_pipeline(std::move(nodes)); - } - - ASSERT_EQ(outs1.size(), static_cast(num_msgs)); - ASSERT_EQ(outs2.size(), static_cast(num_msgs)); - - // Order and identity must be preserved - for (int i = 0; i < num_msgs; ++i) { - EXPECT_EQ(&outs1[i].get(), &outs2[i].get()); - EXPECT_EQ(outs1[i].get(), i); - EXPECT_EQ(outs2[i].get(), i); - } - - // Release semantics: with two refs, release must fail until one is reset. 
- for (int i = 0; i < num_msgs; ++i) { - EXPECT_THROW(std::ignore = outs1[i].release(), std::invalid_argument); - - // Make outs1[i] sole owner - outs2[i].reset(); - - EXPECT_NO_THROW({ - int v = outs1[i].release(); - EXPECT_EQ(v, i); - }); - EXPECT_TRUE(outs1[i].empty()); - } - } \ No newline at end of file +#include +#include + +#include + +#include +#include +#include +#include + +#include "base_streaming_fixture.hpp" + +#include + +using namespace rapidsmpf; +using namespace rapidsmpf::streaming; +namespace node = rapidsmpf::streaming::node; +using rapidsmpf::streaming::node::FanoutPolicy; + +using StreamingFanout = BaseStreamingFixture; + +namespace { + +/** + * @brief Helper to make a sequence of Message with values [0, n). + */ +std::vector make_int_inputs(int n) { + std::vector inputs; + inputs.reserve(n); + for (int i = 0; i < n; ++i) { + inputs.emplace_back( + i, + std::make_unique(i), + ContentDescription{}, + [](Message const& msg, MemoryReservation&) { + return Message{ + msg.sequence_number(), + std::make_unique(msg.get()), + ContentDescription{} + }; + } + ); + } + return inputs; +} + +} // namespace + +TEST_F(StreamingFanout, Bounded) { + int const num_msgs = 10; + + // Prepare inputs + auto inputs = make_int_inputs(num_msgs); + + // Create pipeline + std::vector outs1, outs2, outs3; + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + + auto out1 = ctx->create_channel(); + auto out2 = ctx->create_channel(); + auto out3 = ctx->create_channel(); + nodes.push_back(node::fanout(ctx, in, {out1, out2, out3}, FanoutPolicy::BOUNDED)); + nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); + nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); + nodes.push_back(node::pull_from_channel(ctx, out3, outs3)); + + run_streaming_pipeline(std::move(nodes)); + } + + // Validate sizes + EXPECT_EQ(outs1.size(), static_cast(num_msgs)); + EXPECT_EQ(outs2.size(), static_cast(num_msgs)); + EXPECT_EQ(outs3.size(), static_cast(num_msgs)); + + // Validate ordering/content and that shallow copies share the same underlying object + for (int i = 0; i < num_msgs; ++i) { + EXPECT_EQ(outs1[i].get(), i); + EXPECT_EQ(outs2[i].get(), i); + EXPECT_EQ(outs3[i].get(), i); + } +} + +TEST_F(StreamingFanout, Unbounded) { + int const num_msgs = 7; + + auto inputs = make_int_inputs(num_msgs); + + std::vector outs1, outs2; + { + std::vector nodes; + + auto in = ctx->create_channel(); + auto out1 = ctx->create_channel(); + auto out2 = ctx->create_channel(); + + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + + // UNBOUNDED policy: buffer all inputs, then broadcast after input closes. 
+ nodes.push_back(node::fanout(ctx, in, {out1, out2}, FanoutPolicy::UNBOUNDED)); + + nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); + nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); + + run_streaming_pipeline(std::move(nodes)); + } + + ASSERT_EQ(outs1.size(), static_cast(num_msgs)); + ASSERT_EQ(outs2.size(), static_cast(num_msgs)); + + // Order and identity must be preserved + for (int i = 0; i < num_msgs; ++i) { + EXPECT_EQ(outs1[i].get(), i); + EXPECT_EQ(outs2[i].get(), i); + } +} From a4810a15cec0d9d7fb3d0ac7493a5ec24103e251 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 5 Nov 2025 16:55:21 -0800 Subject: [PATCH 05/43] adding more tests Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 161 +++++++++++++++++++++------- 1 file changed, 123 insertions(+), 38 deletions(-) diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index fd3e863c9..f6c8df881 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -2,6 +2,7 @@ * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: Apache-2.0 */ +#include #include #include @@ -22,8 +23,6 @@ using namespace rapidsmpf::streaming; namespace node = rapidsmpf::streaming::node; using rapidsmpf::streaming::node::FanoutPolicy; -using StreamingFanout = BaseStreamingFixture; - namespace { /** @@ -49,76 +48,162 @@ std::vector make_int_inputs(int n) { return inputs; } +std::string policy_to_string(FanoutPolicy policy) { + switch (policy) { + case FanoutPolicy::BOUNDED: + return "bounded"; + case FanoutPolicy::UNBOUNDED: + return "unbounded"; + default: + return "unknown"; + } +} + } // namespace -TEST_F(StreamingFanout, Bounded) { - int const num_msgs = 10; +class StreamingFanout + : public BaseStreamingFixture, + public ::testing::WithParamInterface> { + public: + void SetUp() override { + std::tie(policy, num_threads, num_out_chs, num_msgs) = GetParam(); + SetUpWithThreads(num_threads); + } + FanoutPolicy policy; + int num_threads; + int num_out_chs; + int num_msgs; +}; + +INSTANTIATE_TEST_SUITE_P( + StreamingFanout, + StreamingFanout, + ::testing::Combine( + ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), + ::testing::Values(1, 2, 4), // number of threads + ::testing::Values(1, 2, 4), // number of output channels + ::testing::Values(1, 10, 100) // number of messages + ), + [](testing::TestParamInfo const& info) { + return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" + + std::to_string(std::get<1>(info.param)) + "_nch_out_" + + std::to_string(std::get<2>(info.param)) + "_nmsgs_" + + std::to_string(std::get<3>(info.param)); + } +); + +TEST_P(StreamingFanout, SinkPerChannel) { // Prepare inputs auto inputs = make_int_inputs(num_msgs); // Create pipeline - std::vector outs1, outs2, outs3; + std::vector> outs(num_out_chs); { std::vector nodes; auto in = ctx->create_channel(); - nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } - auto out1 = ctx->create_channel(); - auto out2 = ctx->create_channel(); - auto out3 = ctx->create_channel(); - nodes.push_back(node::fanout(ctx, in, {out1, out2, out3}, FanoutPolicy::BOUNDED)); - nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); - nodes.push_back(node::pull_from_channel(ctx, out2, 
outs2)); - nodes.push_back(node::pull_from_channel(ctx, out3, outs3)); + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], outs[i])); + } run_streaming_pipeline(std::move(nodes)); } - // Validate sizes - EXPECT_EQ(outs1.size(), static_cast(num_msgs)); - EXPECT_EQ(outs2.size(), static_cast(num_msgs)); - EXPECT_EQ(outs3.size(), static_cast(num_msgs)); + for (int c = 0; c < num_out_chs; ++c) { + // Validate sizes + EXPECT_EQ(outs[c].size(), static_cast(num_msgs)); + + // Validate ordering/content and that shallow copies share the same underlying + // object + for (int i = 0; i < num_msgs; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } + } +} + +namespace { - // Validate ordering/content and that shallow copies share the same underlying object - for (int i = 0; i < num_msgs; ++i) { - EXPECT_EQ(outs1[i].get(), i); - EXPECT_EQ(outs2[i].get(), i); - EXPECT_EQ(outs3[i].get(), i); +enum class ConsumePolicy : uint8_t { + CHANNEL_ORDER, // consume messages in the order of the channels + MESSAGE_ORDER, // consume messages in the order of the messages +}; + +Node many_input_sink( + std::shared_ptr const& ctx, + std::vector>& chs, + ConsumePolicy consume_policy, + std::vector>& outs +) { + ShutdownAtExit c{chs}; + co_await ctx->executor()->schedule(); + + if (consume_policy == ConsumePolicy::CHANNEL_ORDER) { + for (size_t i = 0; i < chs.size(); ++i) { + while (true) { + auto msg = co_await chs[i]->receive(); + if (msg.empty()) { + break; + } + outs[i].push_back(std::move(msg)); + } + } + } else if (consume_policy == ConsumePolicy::MESSAGE_ORDER) { + // while (true) { + // auto msg = co_await chs[0]->receive(); + // outs[0].push_back(msg); + // } } } -TEST_F(StreamingFanout, Unbounded) { - int const num_msgs = 7; +} // namespace + +TEST_P(StreamingFanout, ManyInputSink_ChannelOrder) { + if (policy == FanoutPolicy::BOUNDED) { + GTEST_SKIP() << "Bounded fanout does not support this consume policy"; + } auto inputs = make_int_inputs(num_msgs); - std::vector outs1, outs2; + std::vector> outs(num_out_chs); { std::vector nodes; auto in = ctx->create_channel(); - auto out1 = ctx->create_channel(); - auto out2 = ctx->create_channel(); - nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); - // UNBOUNDED policy: buffer all inputs, then broadcast after input closes. 
- nodes.push_back(node::fanout(ctx, in, {out1, out2}, FanoutPolicy::UNBOUNDED)); + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } - nodes.push_back(node::pull_from_channel(ctx, out1, outs1)); - nodes.push_back(node::pull_from_channel(ctx, out2, outs2)); + nodes.push_back(node::fanout(ctx, in, out_chs, policy)); + + nodes.push_back(many_input_sink(ctx, out_chs, ConsumePolicy::CHANNEL_ORDER, outs) + ); run_streaming_pipeline(std::move(nodes)); } - ASSERT_EQ(outs1.size(), static_cast(num_msgs)); - ASSERT_EQ(outs2.size(), static_cast(num_msgs)); + for (int c = 0; c < num_out_chs; ++c) { + // Validate sizes + EXPECT_EQ(outs[c].size(), static_cast(num_msgs)); - // Order and identity must be preserved - for (int i = 0; i < num_msgs; ++i) { - EXPECT_EQ(outs1[i].get(), i); - EXPECT_EQ(outs2[i].get(), i); + // Validate ordering/content and that shallow copies share the same underlying + // object + for (int i = 0; i < num_msgs; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } } } From 0c5e342a0eca7ad0fd731f21de2f20a0edb2869b Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 6 Nov 2025 17:39:07 -0800 Subject: [PATCH 06/43] adding more tests Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/channel.hpp | 7 + .../rapidsmpf/streaming/core/context.hpp | 7 + cpp/src/streaming/core/channel.cpp | 6 + cpp/src/streaming/core/context.cpp | 6 +- cpp/src/streaming/core/fanout.cpp | 127 ++++++++++-------- cpp/src/streaming/core/leaf_node.cpp | 4 +- cpp/tests/streaming/test_fanout.cpp | 110 +++++++++++---- 7 files changed, 181 insertions(+), 86 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 7d813ae83..64473dbd1 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -83,6 +83,13 @@ class Channel { */ [[nodiscard]] bool empty() const noexcept; + /** + * @brief Check whether the channel is shut down. + * + * @return True if the channel is shut down. + */ + [[nodiscard]] bool is_shutdown() const noexcept; + private: Channel() = default; coro::ring_buffer rb_; diff --git a/cpp/include/rapidsmpf/streaming/core/context.hpp b/cpp/include/rapidsmpf/streaming/core/context.hpp index a70c2841b..f4eba0dde 100644 --- a/cpp/include/rapidsmpf/streaming/core/context.hpp +++ b/cpp/include/rapidsmpf/streaming/core/context.hpp @@ -75,6 +75,13 @@ class Context { */ [[nodiscard]] std::shared_ptr comm() const noexcept; + /** + * @brief Returns the logger. + * + * @return Reference to the logger. + */ + [[nodiscard]] Communicator::Logger& logger() const noexcept; + /** * @brief Returns the progress thread. 
* diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index d54d9cc61..962f290eb 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -8,6 +8,7 @@ namespace rapidsmpf::streaming { coro::task Channel::send(Message msg) { + RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); auto result = co_await rb_.produce(std::move(msg)); co_return result == coro::ring_buffer_result::produce::produced; } @@ -15,6 +16,7 @@ coro::task Channel::send(Message msg) { coro::task Channel::receive() { auto msg = co_await rb_.consume(); if (msg.has_value()) { + RAPIDSMPF_EXPECTS(!msg->empty(), "received empty message"); co_return std::move(*msg); } else { co_return Message{}; @@ -33,4 +35,8 @@ bool Channel::empty() const noexcept { return rb_.empty(); } +bool Channel::is_shutdown() const noexcept { + return rb_.is_shutdown(); +} + } // namespace rapidsmpf::streaming diff --git a/cpp/src/streaming/core/context.cpp b/cpp/src/streaming/core/context.cpp index 0e1577dc8..2b0699119 100644 --- a/cpp/src/streaming/core/context.cpp +++ b/cpp/src/streaming/core/context.cpp @@ -71,6 +71,10 @@ std::shared_ptr Context::comm() const noexcept { return comm_; } +Communicator::Logger& Context::logger() const noexcept { + return comm_->logger(); +} + std::shared_ptr Context::progress_thread() const noexcept { return progress_thread_; } @@ -88,6 +92,6 @@ std::shared_ptr Context::statistics() const noexcept { } std::shared_ptr Context::create_channel() const noexcept { - return std::unique_ptr(new Channel()); + return std::shared_ptr(new Channel()); } } // namespace rapidsmpf::streaming diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 7d274d359..5c7e60dc5 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -4,6 +4,7 @@ */ #include +#include #include #include @@ -35,30 +36,28 @@ Node send_to_channels( } Node bounded_fanout( - std::shared_ptr&& ctx, - std::shared_ptr&& ch_in, - std::vector>&& chs_out + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out ) { ShutdownAtExit c1{ch_in}; ShutdownAtExit c2{chs_out}; + auto& logger = ctx->logger(); co_await ctx->executor()->schedule(); - while (true) { auto msg = co_await ch_in->receive(); if (msg.empty()) { break; } + co_await send_to_channels(ctx.get(), msg, chs_out); + logger.debug("Sent message ", msg.sequence_number()); } - // Finally, we drain all output channels. 
- std::vector tasks; - tasks.reserve(chs_out.size()); - std::ranges::transform(chs_out, std::back_inserter(tasks), [&](auto& ch_out) { - return ch_out->drain(ctx->executor()); - }); - - coro_results(co_await coro::when_all(std::move(tasks))); + for (auto& ch : chs_out) { + co_await ch->drain(ctx->executor()); + } + logger.debug("Completed bounded fanout"); } Node unbounded_fo_send_task( @@ -70,40 +69,47 @@ Node unbounded_fo_send_task( coro::condition_variable& data_ready, coro::condition_variable& request_data, bool const& input_done, - std::vector const& recv_messages + std::deque const& recv_messages ) { - ShutdownAtExit c{chs_out[idx]}; + auto& logger = ctx.logger(); + ShutdownAtExit ch_shutdown{chs_out[idx]}; co_await ctx.executor()->schedule(); - size_t end_idx; + + size_t curr_recv_msg_sz; while (true) { { auto lock = co_await mtx.scoped_lock(); co_await data_ready.wait(lock, [&] { // irrespective of input_done, update the end_idx to the total number of // messages - end_idx = recv_messages.size(); - return input_done || ch_next_idx[idx] < end_idx; + curr_recv_msg_sz = recv_messages.size(); + return input_done || ch_next_idx[idx] < curr_recv_msg_sz; }); - if (input_done && ch_next_idx[idx] == end_idx) { + if (input_done && ch_next_idx[idx] == curr_recv_msg_sz) { + // no more messages will be received, and all messages have been sent break; } } // now we can copy & send messages in indices [next_idx, end_idx) - for (size_t i = ch_next_idx[idx]; i < end_idx; i++) { + for (size_t i = ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { auto const& msg = recv_messages[i]; + RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); // make reservations for each message so that it will fallback to host memory // if needed auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); - co_await chs_out[idx]->send(msg.copy(res)); + RAPIDSMPF_EXPECTS( + co_await chs_out[idx]->send(msg.copy(res)), "failed to send message" + ); } + logger.debug("sent ", idx, " [", ch_next_idx[idx], ", ", curr_recv_msg_sz, ")"); // now next_idx can be updated to end_idx, and if !input_done, we need to request // parent task for more data auto lock = co_await mtx.scoped_lock(); - ch_next_idx[idx] = end_idx; - if (input_done && ch_next_idx[idx] == end_idx) { + ch_next_idx[idx] = curr_recv_msg_sz; + if (input_done && ch_next_idx[idx] == recv_messages.size()) { break; } else { lock.unlock(); @@ -112,50 +118,61 @@ Node unbounded_fo_send_task( } co_await chs_out[idx]->drain(ctx.executor()); + logger.debug("Send task ", idx, " completed"); } Node unbounded_fanout( - std::shared_ptr&& ctx, - std::shared_ptr&& ch_in, - std::vector>&& chs_out + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out ) { - ShutdownAtExit c{ch_in}; - ShutdownAtExit c2{chs_out}; + ShutdownAtExit ch_in_shutdown{ch_in}; + ShutdownAtExit chs_out_shutdown{chs_out}; co_await ctx->executor()->schedule(); + auto& logger = ctx->logger(); + logger.debug("Scheduled unbounded fanout"); coro::mutex mtx; - coro::condition_variable data_ready; - coro::condition_variable request_data; - bool input_done{false}; - std::vector recv_messages; - - std::vector ch_next_idx(chs_out.size(), 0); - - std::vector tasks; - tasks.reserve(chs_out.size()); + coro::condition_variable data_ready; // notify send tasks to copy & send messages + coro::condition_variable + request_data; // notify this task to receive more data from the input channel + bool input_done{false}; // set to true when the input channel is fully consumed + std::deque recv_messages; // messages 
received from the input channel. We + // use a deque to avoid references being + // invalidated by reallocations. + std::vector ch_next_idx( + chs_out.size(), 0 + ); // next index to send for each channel + + coro::task_container tasks(ctx->executor()); for (size_t i = 0; i < chs_out.size(); i++) { - tasks.emplace_back(unbounded_fo_send_task( - *ctx, - i, - chs_out, - ch_next_idx, - mtx, - data_ready, - request_data, - input_done, - recv_messages - )); + RAPIDSMPF_EXPECTS( + tasks.start(unbounded_fo_send_task( + *ctx, + i, + chs_out, + ch_next_idx, + mtx, + data_ready, + request_data, + input_done, + recv_messages + )), + "failed to start send task" + ); } - size_t purge_idx = 0; + size_t purge_idx = 0; // index of the first message to purge + // input_done is only set by this task, so reading without lock is safe here while (!input_done) { { auto lock = co_await mtx.scoped_lock(); co_await request_data.wait(lock, [&] { + // return recv_messages.size() <= std::ranges::max(ch_next_idx); return std::ranges::any_of(ch_next_idx, [&](size_t next_idx) { - return recv_messages.size() >= next_idx; + return recv_messages.size() == next_idx; }); }); } @@ -168,9 +185,11 @@ Node unbounded_fanout( if (msg.empty()) { input_done = true; } else { + logger.debug("Received input", msg.sequence_number()); recv_messages.emplace_back(std::move(msg)); } } + // notify send_tasks to copy & send messages co_await data_ready.notify_all(); @@ -187,10 +206,10 @@ Node unbounded_fanout( // Note: there will be some messages to be purged after the loop exits, but we don't // need to do anything about them here - coro_results(co_await coro::when_all(std::move(tasks))); + co_await tasks.yield_until_empty(); + logger.debug("Unbounded fanout completed"); } - } // namespace Node fanout( @@ -201,11 +220,9 @@ Node fanout( ) { switch (policy) { case FanoutPolicy::BOUNDED: - co_await bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); - break; + return bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); case FanoutPolicy::UNBOUNDED: - co_await unbounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); - break; + return unbounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); default: RAPIDSMPF_FAIL("Unknown broadcast policy", std::invalid_argument); } diff --git a/cpp/src/streaming/core/leaf_node.cpp b/cpp/src/streaming/core/leaf_node.cpp index c2b380f1c..74f259126 100644 --- a/cpp/src/streaming/core/leaf_node.cpp +++ b/cpp/src/streaming/core/leaf_node.cpp @@ -18,7 +18,9 @@ Node push_to_channel( for (auto& msg : messages) { RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty", std::invalid_argument); - co_await ch_out->send(std::move(msg)); + RAPIDSMPF_EXPECTS( + co_await ch_out->send(std::move(msg)), "failed to send message" + ); } co_await ch_out->drain(ctx->executor()); } diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index f6c8df881..2fe9bdc1d 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -2,6 +2,7 @@ * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: Apache-2.0 */ +#include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include "base_streaming_fixture.hpp" @@ -31,19 +33,18 @@ namespace { std::vector make_int_inputs(int n) { std::vector inputs; inputs.reserve(n); - for (int i = 0; i < n; ++i) { - inputs.emplace_back( - i, - std::make_unique(i), + + Message::CopyCallback copy_cb = [](Message const& msg, MemoryReservation&) { + return Message{ + msg.sequence_number(), + std::make_unique(msg.get()), ContentDescription{}, - [](Message const& msg, MemoryReservation&) { - return Message{ - msg.sequence_number(), - std::make_unique(msg.get()), - ContentDescription{} - }; - } - ); + msg.copy_cb() + }; + }; + + for (int i = 0; i < n; ++i) { + inputs.emplace_back(i, std::make_unique(i), ContentDescription{}, copy_cb); } return inputs; } @@ -103,11 +104,13 @@ TEST_P(StreamingFanout, SinkPerChannel) { std::vector nodes; auto in = ctx->create_channel(); + std::cout << "Created input channel " << in.get() << std::endl; nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); std::vector> out_chs; for (int i = 0; i < num_out_chs; ++i) { out_chs.emplace_back(ctx->create_channel()); + std::cout << "Created output channel " << out_chs.back().get() << std::endl; } nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); @@ -135,13 +138,15 @@ TEST_P(StreamingFanout, SinkPerChannel) { namespace { enum class ConsumePolicy : uint8_t { - CHANNEL_ORDER, // consume messages in the order of the channels - MESSAGE_ORDER, // consume messages in the order of the messages + CHANNEL_ORDER, // consume all messages from a single channel before moving to the + // next + MESSAGE_ORDER, // consume messages from all channels before moving to the next + // message }; Node many_input_sink( - std::shared_ptr const& ctx, - std::vector>& chs, + std::shared_ptr ctx, + std::vector> chs, ConsumePolicy consume_policy, std::vector>& outs ) { @@ -159,10 +164,20 @@ Node many_input_sink( } } } else if (consume_policy == ConsumePolicy::MESSAGE_ORDER) { - // while (true) { - // auto msg = co_await chs[0]->receive(); - // outs[0].push_back(msg); - // } + std::unordered_set finished_chs{}; + while (finished_chs.size() < chs.size()) { + for (size_t i = 0; i < chs.size(); ++i) { + if (finished_chs.contains(i)) { + continue; + } + auto msg = co_await chs[i]->receive(); + if (msg.empty()) { + finished_chs.insert(i); + } else { + outs[i].emplace_back(std::move(msg)); + } + } + } } } @@ -170,7 +185,7 @@ Node many_input_sink( TEST_P(StreamingFanout, ManyInputSink_ChannelOrder) { if (policy == FanoutPolicy::BOUNDED) { - GTEST_SKIP() << "Bounded fanout does not support this consume policy"; + GTEST_SKIP() << "Bounded fanout does not support channel order"; } auto inputs = make_int_inputs(num_msgs); @@ -195,15 +210,52 @@ TEST_P(StreamingFanout, ManyInputSink_ChannelOrder) { run_streaming_pipeline(std::move(nodes)); } + std::vector expected(num_msgs); + std::iota(expected.begin(), expected.end(), 0); for (int c = 0; c < num_out_chs; ++c) { - // Validate sizes - EXPECT_EQ(outs[c].size(), static_cast(num_msgs)); + SCOPED_TRACE("channel " + std::to_string(c)); + std::vector actual; + actual.reserve(outs[c].size()); + std::ranges::transform(outs[c], std::back_inserter(actual), [](const Message& m) { + return m.get(); + }); + EXPECT_EQ(expected, actual); + } +} - // Validate ordering/content and that shallow copies share the same underlying - // object - for (int i = 0; i < num_msgs; ++i) { - SCOPED_TRACE("channel " + std::to_string(c) 
+ " idx " + std::to_string(i)); - EXPECT_EQ(outs[c][i].get(), i); +TEST_P(StreamingFanout, ManyInputSink_MessageOrder) { + auto inputs = make_int_inputs(num_msgs); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + nodes.emplace_back( + many_input_sink(ctx, out_chs, ConsumePolicy::MESSAGE_ORDER, outs) + ); + + run_streaming_pipeline(std::move(nodes)); } -} + + std::vector expected(num_msgs); + std::iota(expected.begin(), expected.end(), 0); + for (int c = 0; c < num_out_chs; ++c) { + SCOPED_TRACE("channel " + std::to_string(c)); + std::vector actual; + actual.reserve(outs[c].size()); + std::ranges::transform(outs[c], std::back_inserter(actual), [](const Message& m) { + return m.get(); + }); + EXPECT_EQ(expected, actual); + } +} \ No newline at end of file From 80813320af970644cc377812dfd7c11f2fdb2ae7 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 7 Nov 2025 12:12:43 -0800 Subject: [PATCH 07/43] extending tests Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 181 +++++++++++++++++----------- cpp/tests/streaming/test_fanout.cpp | 127 +++++++++---------- 2 files changed, 167 insertions(+), 141 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 5c7e60dc5..46ff2e125 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -7,13 +7,17 @@ #include #include +#include #include #include +#include +#include #include namespace rapidsmpf::streaming::node { namespace { + /** * @brief Asynchronously send a message to multiple output channels. * @@ -35,6 +39,17 @@ Node send_to_channels( coro_results(co_await coro::when_all(std::move(tasks))); } +/** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * @note Bounded fanout requires all the output channels to consume messages before the + * next message is sent/ consumed from the input channel. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param chs_out The output channels to send messages to. + * @return A node representing the bounded fanout operation. + */ Node bounded_fanout( std::shared_ptr ctx, std::shared_ptr ch_in, @@ -60,67 +75,114 @@ Node bounded_fanout( logger.debug("Completed bounded fanout"); } +/** + * @brief State for the unbounded fanout. + */ +struct UnboundedFanoutState { + UnboundedFanoutState(size_t num_channels) : ch_next_idx(num_channels, 0) {} + + coro::mutex mtx; + // notify send tasks to copy & send messages + coro::condition_variable data_ready; + // notify this task to receive more data from the input channel + coro::condition_variable request_data; + // set to true when the input channel is fully consumed + bool input_done{false}; + // messages received from the input channel. We use a deque to avoid references being + // invalidated by reallocations. + std::deque recv_messages; + // next index to send for each channel + std::vector ch_next_idx; + // index of the first message to purge + size_t purge_idx{0}; +}; + +/** + * @brief Send messages to multiple output channels. + * + * @param ctx The context to use. + * @param idx The index of the task + * @param ch_out The output channel to send messages to. 
+ * @param state The state of the unbounded fanout. + * @return A coroutine representing the task. + */ Node unbounded_fo_send_task( Context& ctx, size_t idx, - std::vector> const& chs_out, - std::vector& ch_next_idx, - coro::mutex& mtx, - coro::condition_variable& data_ready, - coro::condition_variable& request_data, - bool const& input_done, - std::deque const& recv_messages + std::shared_ptr& ch_out, + UnboundedFanoutState& state ) { - auto& logger = ctx.logger(); - ShutdownAtExit ch_shutdown{chs_out[idx]}; + ShutdownAtExit ch_shutdown{ch_out}; co_await ctx.executor()->schedule(); - size_t curr_recv_msg_sz; + auto& logger = ctx.logger(); + + size_t curr_recv_msg_sz = 0; // current size of the recv_messages deque while (true) { { - auto lock = co_await mtx.scoped_lock(); - co_await data_ready.wait(lock, [&] { + auto lock = co_await state.mtx.scoped_lock(); + co_await state.data_ready.wait(lock, [&] { // irrespective of input_done, update the end_idx to the total number of // messages - curr_recv_msg_sz = recv_messages.size(); - return input_done || ch_next_idx[idx] < curr_recv_msg_sz; + curr_recv_msg_sz = state.recv_messages.size(); + return state.input_done || state.ch_next_idx[idx] < curr_recv_msg_sz; }); - if (input_done && ch_next_idx[idx] == curr_recv_msg_sz) { + if (state.input_done && state.ch_next_idx[idx] == curr_recv_msg_sz) { // no more messages will be received, and all messages have been sent break; } } // now we can copy & send messages in indices [next_idx, end_idx) - for (size_t i = ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { - auto const& msg = recv_messages[i]; + for (size_t i = state.ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { + auto const& msg = state.recv_messages[i]; RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); // make reservations for each message so that it will fallback to host memory // if needed auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); RAPIDSMPF_EXPECTS( - co_await chs_out[idx]->send(msg.copy(res)), "failed to send message" + co_await ch_out->send(msg.copy(res)), "failed to send message" ); } - logger.debug("sent ", idx, " [", ch_next_idx[idx], ", ", curr_recv_msg_sz, ")"); + logger.trace( + "sent ", idx, " [", state.ch_next_idx[idx], ", ", curr_recv_msg_sz, ")" + ); // now next_idx can be updated to end_idx, and if !input_done, we need to request // parent task for more data - auto lock = co_await mtx.scoped_lock(); - ch_next_idx[idx] = curr_recv_msg_sz; - if (input_done && ch_next_idx[idx] == recv_messages.size()) { - break; - } else { - lock.unlock(); - co_await request_data.notify_one(); + auto lock = co_await state.mtx.scoped_lock(); + state.ch_next_idx[idx] = curr_recv_msg_sz; + if (state.ch_next_idx[idx] == state.recv_messages.size()) { + if (state.input_done) { + break; // no more messages will be received, and all messages have been + // sent + } else { + // request more data from the input channel + lock.unlock(); + co_await state.request_data.notify_one(); + } } } - co_await chs_out[idx]->drain(ctx.executor()); + co_await ch_out->drain(ctx.executor()); logger.debug("Send task ", idx, " completed"); } +/** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * This is a general purpose implementation which can support consuming messages by any + * channel. A consumer node can decide to consume all messages from a single channel + * before moving to the next channel, or it can consume messages from all channels before + * moving to the next message. 
When a message has been sent to all output channels, it is + * purged from the internal deque. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param chs_out The output channels to send messages to. + * @return A node representing the unbounded fanout operation. + */ Node unbounded_fanout( std::shared_ptr ctx, std::shared_ptr ch_in, @@ -133,46 +195,24 @@ Node unbounded_fanout( auto& logger = ctx->logger(); logger.debug("Scheduled unbounded fanout"); - coro::mutex mtx; - coro::condition_variable data_ready; // notify send tasks to copy & send messages - coro::condition_variable - request_data; // notify this task to receive more data from the input channel - bool input_done{false}; // set to true when the input channel is fully consumed - std::deque recv_messages; // messages received from the input channel. We - // use a deque to avoid references being - // invalidated by reallocations. - std::vector ch_next_idx( - chs_out.size(), 0 - ); // next index to send for each channel + UnboundedFanoutState state(chs_out.size()); + // start send tasks for each output channel coro::task_container tasks(ctx->executor()); for (size_t i = 0; i < chs_out.size(); i++) { RAPIDSMPF_EXPECTS( - tasks.start(unbounded_fo_send_task( - *ctx, - i, - chs_out, - ch_next_idx, - mtx, - data_ready, - request_data, - input_done, - recv_messages - )), + tasks.start(unbounded_fo_send_task(*ctx, i, chs_out[i], state)), "failed to start send task" ); } - size_t purge_idx = 0; // index of the first message to purge - // input_done is only set by this task, so reading without lock is safe here - while (!input_done) { + while (!state.input_done) { { - auto lock = co_await mtx.scoped_lock(); - co_await request_data.wait(lock, [&] { - // return recv_messages.size() <= std::ranges::max(ch_next_idx); - return std::ranges::any_of(ch_next_idx, [&](size_t next_idx) { - return recv_messages.size() == next_idx; + auto lock = co_await state.mtx.scoped_lock(); + co_await state.request_data.wait(lock, [&] { + return std::ranges::any_of(state.ch_next_idx, [&](size_t next_idx) { + return state.recv_messages.size() == next_idx; }); }); } @@ -181,27 +221,32 @@ Node unbounded_fanout( auto msg = co_await ch_in->receive(); { // relock mtx to update input_done/ recv_messages - auto lock = co_await mtx.scoped_lock(); + auto lock = co_await state.mtx.scoped_lock(); if (msg.empty()) { - input_done = true; + state.input_done = true; } else { - logger.debug("Received input", msg.sequence_number()); - recv_messages.emplace_back(std::move(msg)); + logger.trace("Received input", msg.sequence_number()); + state.recv_messages.emplace_back(std::move(msg)); } } - + // notify send_tasks to copy & send messages - co_await data_ready.notify_all(); + co_await state.data_ready.notify_all(); - // purge completed send_tasks + // purge completed send_tasks. This will reset the messages to empty, so that they + // release the memory, however the deque is not resized. This guarantees that the + // indices are not invalidated. 
// intentionally not locking the mtx here, because we only need to know a // lower-bound on the last completed idx (ch_next_idx values are monotonically // increasing) - size_t last_completed_idx = std::ranges::min(ch_next_idx); - while (purge_idx + 1 < last_completed_idx) { - recv_messages[purge_idx].reset(); - purge_idx++; + size_t last_completed_idx = std::ranges::min(state.ch_next_idx); + while (state.purge_idx + 1 < last_completed_idx) { + state.recv_messages[state.purge_idx].reset(); + state.purge_idx++; } + logger.trace( + "recv_messages active size: ", state.recv_messages.size() - state.purge_idx + ); } // Note: there will be some messages to be purged after the loop exits, but we don't diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 2fe9bdc1d..2ce396568 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -25,8 +25,6 @@ using namespace rapidsmpf::streaming; namespace node = rapidsmpf::streaming::node; using rapidsmpf::streaming::node::FanoutPolicy; -namespace { - /** * @brief Helper to make a sequence of Message with values [0, n). */ @@ -60,8 +58,6 @@ std::string policy_to_string(FanoutPolicy policy) { } } -} // namespace - class StreamingFanout : public BaseStreamingFixture, public ::testing::WithParamInterface> { @@ -104,13 +100,11 @@ TEST_P(StreamingFanout, SinkPerChannel) { std::vector nodes; auto in = ctx->create_channel(); - std::cout << "Created input channel " << in.get() << std::endl; nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); std::vector> out_chs; for (int i = 0; i < num_out_chs; ++i) { out_chs.emplace_back(ctx->create_channel()); - std::cout << "Created output channel " << out_chs.back().get() << std::endl; } nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); @@ -135,8 +129,6 @@ TEST_P(StreamingFanout, SinkPerChannel) { } } -namespace { - enum class ConsumePolicy : uint8_t { CHANNEL_ORDER, // consume all messages from a single channel before moving to the // next @@ -181,81 +173,70 @@ Node many_input_sink( } } -} // namespace +struct ManyInputSinkStreamingFanout : public StreamingFanout { + void run(ConsumePolicy consume_policy) { + auto inputs = make_int_inputs(num_msgs); -TEST_P(StreamingFanout, ManyInputSink_ChannelOrder) { - if (policy == FanoutPolicy::BOUNDED) { - GTEST_SKIP() << "Bounded fanout does not support channel order"; - } + std::vector> outs(num_out_chs); + { + std::vector nodes; - auto inputs = make_int_inputs(num_msgs); + auto in = ctx->create_channel(); + nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); - std::vector> outs(num_out_chs); - { - std::vector nodes; + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } - auto in = ctx->create_channel(); - nodes.push_back(node::push_to_channel(ctx, in, std::move(inputs))); + nodes.push_back(node::fanout(ctx, in, out_chs, policy)); - std::vector> out_chs; - for (int i = 0; i < num_out_chs; ++i) { - out_chs.emplace_back(ctx->create_channel()); - } + nodes.push_back(many_input_sink(ctx, out_chs, consume_policy, outs)); - nodes.push_back(node::fanout(ctx, in, out_chs, policy)); - - nodes.push_back(many_input_sink(ctx, out_chs, ConsumePolicy::CHANNEL_ORDER, outs) - ); + run_streaming_pipeline(std::move(nodes)); + } - run_streaming_pipeline(std::move(nodes)); + std::vector expected(num_msgs); + std::iota(expected.begin(), expected.end(), 0); + for (int c = 0; c < num_out_chs; ++c) { + SCOPED_TRACE("channel " + 
std::to_string(c)); + std::vector actual; + actual.reserve(outs[c].size()); + std::ranges::transform( + outs[c], std::back_inserter(actual), [](const Message& m) { + return m.get(); + } + ); + EXPECT_EQ(expected, actual); + } } +}; - std::vector expected(num_msgs); - std::iota(expected.begin(), expected.end(), 0); - for (int c = 0; c < num_out_chs; ++c) { - SCOPED_TRACE("channel " + std::to_string(c)); - std::vector actual; - actual.reserve(outs[c].size()); - std::ranges::transform(outs[c], std::back_inserter(actual), [](const Message& m) { - return m.get(); - }); - EXPECT_EQ(expected, actual); +INSTANTIATE_TEST_SUITE_P( + ManyInputSinkStreamingFanout, + ManyInputSinkStreamingFanout, + ::testing::Combine( + ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), + ::testing::Values(1, 2, 4), // number of threads + ::testing::Values(1, 2, 4), // number of output channels + ::testing::Values(1, 10, 100) // number of messages + ), + [](testing::TestParamInfo const& info) { + return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" + + std::to_string(std::get<1>(info.param)) + "_nch_out_" + + std::to_string(std::get<2>(info.param)) + "_nmsgs_" + + std::to_string(std::get<3>(info.param)); } -} - -TEST_P(StreamingFanout, ManyInputSink_MessageOrder) { - auto inputs = make_int_inputs(num_msgs); - - std::vector> outs(num_out_chs); - { - std::vector nodes; - - auto in = ctx->create_channel(); - nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); - - std::vector> out_chs; - for (int i = 0; i < num_out_chs; ++i) { - out_chs.emplace_back(ctx->create_channel()); - } - - nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); - - nodes.emplace_back( - many_input_sink(ctx, out_chs, ConsumePolicy::MESSAGE_ORDER, outs) - ); +); - run_streaming_pipeline(std::move(nodes)); +TEST_P(ManyInputSinkStreamingFanout, ChannelOrder) { + if (policy == FanoutPolicy::BOUNDED) { + GTEST_SKIP() << "Bounded fanout does not support channel order"; } - std::vector expected(num_msgs); - std::iota(expected.begin(), expected.end(), 0); - for (int c = 0; c < num_out_chs; ++c) { - SCOPED_TRACE("channel " + std::to_string(c)); - std::vector actual; - actual.reserve(outs[c].size()); - std::ranges::transform(outs[c], std::back_inserter(actual), [](const Message& m) { - return m.get(); - }); - EXPECT_EQ(expected, actual); - } -} \ No newline at end of file + EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::CHANNEL_ORDER)); +} + +TEST_P(ManyInputSinkStreamingFanout, MessageOrder) { + EXPECT_NO_FATAL_FAILURE(run(ConsumePolicy::MESSAGE_ORDER)); +} From 37ea50d58c6290e1a0e5831b3d4edc453ec8e807 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 7 Nov 2025 12:42:25 -0800 Subject: [PATCH 08/43] minor changes Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 2ce396568..988002d87 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -156,17 +156,18 @@ Node many_input_sink( } } } else if (consume_policy == ConsumePolicy::MESSAGE_ORDER) { - std::unordered_set finished_chs{}; - while (finished_chs.size() < chs.size()) { - for (size_t i = 0; i < chs.size(); ++i) { - if (finished_chs.contains(i)) { - continue; - } - auto msg = co_await chs[i]->receive(); + std::unordered_set active_chs{}; + for (size_t i = 0; i < chs.size(); ++i) { + active_chs.insert(i); + } + while 
(!active_chs.empty()) { + for (auto it = active_chs.begin(); it != active_chs.end();){ + auto msg = co_await chs[*it]->receive(); if (msg.empty()) { - finished_chs.insert(i); + it = active_chs.erase(it); } else { - outs[i].emplace_back(std::move(msg)); + outs[*it].emplace_back(std::move(msg)); + it++; } } } From ff962c7be26480b0d39ea2793858d696e7df4cb7 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 7 Nov 2025 14:29:29 -0800 Subject: [PATCH 09/43] add python bindings Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 2 +- .../rapidsmpf/streaming/core/CMakeLists.txt | 4 +- .../rapidsmpf/streaming/core/fanout.pxd | 25 +++ .../rapidsmpf/streaming/core/fanout.pyi | 21 +++ .../rapidsmpf/streaming/core/fanout.pyx | 111 ++++++++++++ .../rapidsmpf/tests/streaming/test_fanout.py | 163 ++++++++++++++++++ 6 files changed, 323 insertions(+), 3 deletions(-) create mode 100644 python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd create mode 100644 python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi create mode 100644 python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx create mode 100644 python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 988002d87..0bff6ee33 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -161,7 +161,7 @@ Node many_input_sink( active_chs.insert(i); } while (!active_chs.empty()) { - for (auto it = active_chs.begin(); it != active_chs.end();){ + for (auto it = active_chs.begin(); it != active_chs.end();) { auto msg = co_await chs[*it]->receive(); if (msg.empty()) { it = active_chs.erase(it); diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt b/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt index 2725bf6dd..8e990cbc4 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt +++ b/python/rapidsmpf/rapidsmpf/streaming/core/CMakeLists.txt @@ -5,8 +5,8 @@ # cmake-format: on # ================================================================================= -set(cython_modules channel.pyx context.pyx leaf_node.pyx lineariser.pyx message.pyx node.pyx - utilities.pyx +set(cython_modules channel.pyx context.pyx fanout.pyx leaf_node.pyx lineariser.pyx message.pyx + node.pyx utilities.pyx ) rapids_cython_create_modules( diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd new file mode 100644 index 000000000..3f74bbf9e --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +from libc.stdint cimport uint8_t +from libcpp.memory cimport shared_ptr +from libcpp.vector cimport vector + +from rapidsmpf.streaming.core.channel cimport cpp_Channel +from rapidsmpf.streaming.core.context cimport cpp_Context +from rapidsmpf.streaming.core.node cimport cpp_Node + + +cdef extern from "" namespace "rapidsmpf::streaming::node" nogil: + cdef enum class cpp_FanoutPolicy "rapidsmpf::streaming::node::FanoutPolicy" (uint8_t): + BOUNDED "rapidsmpf::streaming::node::FanoutPolicy::BOUNDED" + UNBOUNDED "rapidsmpf::streaming::node::FanoutPolicy::UNBOUNDED" + + cdef cpp_Node cpp_fanout \ + "rapidsmpf::streaming::node::fanout"( + shared_ptr[cpp_Context] ctx, + shared_ptr[cpp_Channel] ch_in, + vector[shared_ptr[cpp_Channel]] chs_out, + cpp_FanoutPolicy policy + ) except + + diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi new file mode 100644 index 000000000..8f8864620 --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyi @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import IntEnum + +from rapidsmpf.streaming.core.channel import Channel +from rapidsmpf.streaming.core.context import Context +from rapidsmpf.streaming.core.node import CppNode + +class FanoutPolicy(IntEnum): + BOUNDED = ... + UNBOUNDED = ... + +def fanout( + ctx: Context, + ch_in: Channel, + chs_out: list[Channel], + policy: FanoutPolicy, +) -> CppNode: ... diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx new file mode 100644 index 000000000..831d437f5 --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from enum import IntEnum + +from libcpp.memory cimport make_unique, shared_ptr +from libcpp.utility cimport move +from libcpp.vector cimport vector + +from rapidsmpf.streaming.core.channel cimport Channel, cpp_Channel +from rapidsmpf.streaming.core.context cimport Context, cpp_Context +from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node + + +class FanoutPolicy(IntEnum): + """ + Fanout policy controlling how messages are propagated. + + Attributes + ---------- + BOUNDED : int + Process messages as they arrive and immediately forward them. + Messages are forwarded as soon as they are received from the input channel. + The next message is not processed until all output channels have completed + sending the current one, ensuring backpressure and synchronized flow. + UNBOUNDED : int + Forward messages without enforcing backpressure. + In this mode, messages may be accumulated internally before being + broadcast, or they may be forwarded immediately depending on the + implementation and downstream consumption rate. + + This mode disables coordinated backpressure between outputs, allowing + consumers to process at independent rates, but can lead to unbounded + buffering and increased memory usage. + + Note: Consumers might not receive any messages until *all* upstream + messages have been sent, depending on the implementation and buffering + strategy. 
+ """ + BOUNDED = cpp_FanoutPolicy.BOUNDED + UNBOUNDED = cpp_FanoutPolicy.UNBOUNDED + + +def fanout(Context ctx, Channel ch_in, list chs_out, policy): + """ + Broadcast messages from one input channel to multiple output channels. + + The node continuously receives messages from the input channel and forwards + them to all output channels according to the selected fanout policy. + + Each output channel receives a shallow copy of the same message; no payload + data is duplicated. All copies share the same underlying payload, ensuring + zero-copy broadcast semantics. + + Parameters + ---------- + ctx : Context + The node context to use. + ch_in : Channel + Input channel from which messages are received. + chs_out : list[Channel] + Output channels to which messages are broadcast. + policy : FanoutPolicy + The fanout strategy to use (see FanoutPolicy). + + Returns + ------- + CppNode + Streaming node representing the fanout operation. + + Raises + ------ + ValueError + If an unknown fanout policy is specified. + + Notes + ----- + Since messages are shallow-copied, releasing a payload (``release()``) + is only valid on messages that hold exclusive ownership of the payload. + + Examples + -------- + >>> import rapidsmpf.streaming.core as streaming + >>> ctx = streaming.Context(...) + >>> ch_in = ctx.create_channel() + >>> ch_out1 = ctx.create_channel() + >>> ch_out2 = ctx.create_channel() + >>> node = streaming.fanout( + ... ctx, ch_in, [ch_out1, ch_out2], streaming.FanoutPolicy.BOUNDED + ... ) + """ + # Validate policy + if not isinstance(policy, (FanoutPolicy, int)): + raise TypeError(f"policy must be a FanoutPolicy enum value, got {type(policy)}") + + cdef vector[shared_ptr[cpp_Channel]] _chs_out + owner = [] + for ch_out in chs_out: + if not isinstance(ch_out, Channel): + raise TypeError("All elements in chs_out must be Channel instances") + owner.append(ch_out) + _chs_out.push_back((ch_out)._handle) + + cdef cpp_FanoutPolicy _policy = (policy) + cdef cpp_Node _ret + with nogil: + _ret = cpp_fanout( + ctx._handle, ch_in._handle, move(_chs_out), _policy + ) + return CppNode.from_handle(make_unique[cpp_Node](move(_ret)), owner) + diff --git a/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py new file mode 100644 index 000000000..a397011e0 --- /dev/null +++ b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for streaming fanout node.""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING + +import pytest + +import cudf + +from rapidsmpf.streaming.core.fanout import FanoutPolicy, fanout +from rapidsmpf.streaming.core.leaf_node import pull_from_channel, push_to_channel +from rapidsmpf.streaming.core.message import Message +from rapidsmpf.streaming.core.node import run_streaming_pipeline +from rapidsmpf.streaming.cudf.table_chunk import TableChunk +from rapidsmpf.testing import assert_eq +from rapidsmpf.utils.cudf import cudf_to_pylibcudf_table + +if TYPE_CHECKING: + from rmm.pylibrmm.stream import Stream + + from rapidsmpf.streaming.core.channel import Channel + from rapidsmpf.streaming.core.context import Context + + +@pytest.mark.parametrize("policy", [FanoutPolicy.BOUNDED, FanoutPolicy.UNBOUNDED]) +def test_fanout_basic(context: Context, stream: Stream, policy: FanoutPolicy) -> None: + """Test basic fanout functionality with multiple output channels.""" + # Create channels + ch_in: Channel[TableChunk] = context.create_channel() + ch_out1: Channel[TableChunk] = context.create_channel() + ch_out2: Channel[TableChunk] = context.create_channel() + + # Create test messages + messages = [] + for i in range(5): + df = cudf.DataFrame( + {"a": [i, i + 1, i + 2], "b": [i * 10, i * 10 + 1, i * 10 + 2]} + ) + chunk = TableChunk.from_pylibcudf_table( + cudf_to_pylibcudf_table(df), stream, exclusive_view=False + ) + messages.append(Message(i, chunk)) + + # Create nodes + push_node = push_to_channel(context, ch_in, messages) + fanout_node = fanout(context, ch_in, [ch_out1, ch_out2], policy) + pull_node1, output1 = pull_from_channel(context, ch_out1) + pull_node2, output2 = pull_from_channel(context, ch_out2) + + # Run pipeline + with ThreadPoolExecutor(max_workers=1) as executor: + run_streaming_pipeline( + nodes=[push_node, fanout_node, pull_node1, pull_node2], + py_executor=executor, + ) + + # Verify results + results1 = output1.release() + results2 = output2.release() + + assert len(results1) == 5, f"Expected 5 messages in output1, got {len(results1)}" + assert len(results2) == 5, f"Expected 5 messages in output2, got {len(results2)}" + + # Check that both outputs received the same sequence numbers and data + for i in range(5): + assert results1[i].sequence_number == i + assert results2[i].sequence_number == i + + chunk1 = TableChunk.from_message(results1[i]) + chunk2 = TableChunk.from_message(results2[i]) + + # Expected data + expected_df = cudf.DataFrame( + {"a": [i, i + 1, i + 2], "b": [i * 10, i * 10 + 1, i * 10 + 2]} + ) + expected_table = cudf_to_pylibcudf_table(expected_df) + + # Verify data is correct + assert_eq(chunk1.table_view(), expected_table) + assert_eq(chunk2.table_view(), expected_table) + + +@pytest.mark.parametrize("num_outputs", [1, 3, 5]) +def test_fanout_multiple_outputs( + context: Context, stream: Stream, num_outputs: int +) -> None: + """Test fanout with varying numbers of output channels.""" + # Create channels + ch_in: Channel[TableChunk] = context.create_channel() + chs_out: list[Channel[TableChunk]] = [ + context.create_channel() for _ in range(num_outputs) + ] + + # Create test messages + messages = [] + for i in range(3): + df = cudf.DataFrame({"x": [i * 10, i * 10 + 1]}) + chunk = TableChunk.from_pylibcudf_table( + cudf_to_pylibcudf_table(df), stream, exclusive_view=False + ) + messages.append(Message(i, chunk)) + + # Create nodes + push_node = 
push_to_channel(context, ch_in, messages) + fanout_node = fanout(context, ch_in, chs_out, FanoutPolicy.BOUNDED) + pull_nodes = [] + outputs = [] + for ch_out in chs_out: + pull_node, output = pull_from_channel(context, ch_out) + pull_nodes.append(pull_node) + outputs.append(output) + + # Run pipeline + with ThreadPoolExecutor(max_workers=1) as executor: + run_streaming_pipeline( + nodes=[push_node, fanout_node, *pull_nodes], + py_executor=executor, + ) + + # Verify all outputs received the messages + for output_idx, output in enumerate(outputs): + results = output.release() + assert len(results) == 3, ( + f"Output {output_idx}: Expected 3 messages, got {len(results)}" + ) + for i in range(3): + assert results[i].sequence_number == i + + +def test_fanout_empty_outputs(context: Context, stream: Stream) -> None: + """Test fanout with empty output list raises appropriate error or handles gracefully.""" + ch_in: Channel[TableChunk] = context.create_channel() + + # This should work but produce no output (or could raise ValueError) + # The C++ implementation may handle this differently + fanout_node = fanout(context, ch_in, [], FanoutPolicy.BOUNDED) + + # Create a simple message + df = cudf.DataFrame({"a": [1, 2, 3]}) + chunk = TableChunk.from_pylibcudf_table( + cudf_to_pylibcudf_table(df), stream, exclusive_view=False + ) + messages = [Message(0, chunk)] + push_node = push_to_channel(context, ch_in, messages) + + # Run pipeline - should complete without error + with ThreadPoolExecutor(max_workers=1) as executor: + run_streaming_pipeline( + nodes=[push_node, fanout_node], + py_executor=executor, + ) + + +def test_fanout_policy_enum() -> None: + """Test that FanoutPolicy enum has correct values.""" + assert FanoutPolicy.BOUNDED == 0 + assert FanoutPolicy.UNBOUNDED == 1 + assert len(FanoutPolicy) == 2 From 0c7d6a397fc6677fd7b8c7023f23cbb03624df7d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 7 Nov 2025 14:56:01 -0800 Subject: [PATCH 10/43] precommit Signed-off-by: niranda perera --- python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd | 7 ++++--- python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx | 7 +++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd index 3f74bbf9e..50ea90b32 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd @@ -10,8 +10,10 @@ from rapidsmpf.streaming.core.context cimport cpp_Context from rapidsmpf.streaming.core.node cimport cpp_Node -cdef extern from "" namespace "rapidsmpf::streaming::node" nogil: - cdef enum class cpp_FanoutPolicy "rapidsmpf::streaming::node::FanoutPolicy" (uint8_t): +cdef extern from "" \ + namespace "rapidsmpf::streaming::node" nogil: + cdef enum class cpp_FanoutPolicy \ + "rapidsmpf::streaming::node::FanoutPolicy" (uint8_t): BOUNDED "rapidsmpf::streaming::node::FanoutPolicy::BOUNDED" UNBOUNDED "rapidsmpf::streaming::node::FanoutPolicy::UNBOUNDED" @@ -22,4 +24,3 @@ cdef extern from "" namespace "rapidsmpf::s vector[shared_ptr[cpp_Channel]] chs_out, cpp_FanoutPolicy policy ) except + - diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx index 831d437f5..889318a6b 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -7,8 +7,8 @@ from libcpp.memory cimport make_unique, shared_ptr from libcpp.utility cimport 
move from libcpp.vector cimport vector -from rapidsmpf.streaming.core.channel cimport Channel, cpp_Channel -from rapidsmpf.streaming.core.context cimport Context, cpp_Context +from rapidsmpf.streaming.core.channel cimport Channel +from rapidsmpf.streaming.core.context cimport Context from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node @@ -92,7 +92,7 @@ def fanout(Context ctx, Channel ch_in, list chs_out, policy): # Validate policy if not isinstance(policy, (FanoutPolicy, int)): raise TypeError(f"policy must be a FanoutPolicy enum value, got {type(policy)}") - + cdef vector[shared_ptr[cpp_Channel]] _chs_out owner = [] for ch_out in chs_out: @@ -108,4 +108,3 @@ def fanout(Context ctx, Channel ch_in, list chs_out, policy): ctx._handle, ch_in._handle, move(_chs_out), _policy ) return CppNode.from_handle(make_unique[cpp_Node](move(_ret)), owner) - From 3b124ca6d920f82b06e02a828d351fb2a883328f Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Sun, 9 Nov 2025 08:00:33 -0800 Subject: [PATCH 11/43] Update cpp/src/streaming/core/fanout.cpp Co-authored-by: Mads R. B. Kristensen --- cpp/src/streaming/core/fanout.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 46ff2e125..dae60c673 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -43,7 +43,7 @@ Node send_to_channels( * @brief Broadcast messages from one input channel to multiple output channels. * * @note Bounded fanout requires all the output channels to consume messages before the - * next message is sent/ consumed from the input channel. + * next message is sent/consumed from the input channel. * * @param ctx The context to use. * @param ch_in The input channel to receive messages from. From ea945ca7cf11de12cdb1df31fcff05d69719bc92 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 10 Nov 2025 11:25:59 -0800 Subject: [PATCH 12/43] adding lower mem types Signed-off-by: niranda perera --- cpp/include/rapidsmpf/buffer/buffer.hpp | 16 ++++++- .../rapidsmpf/buffer/content_description.hpp | 39 +++++++++++++++++ cpp/include/rapidsmpf/buffer/resource.hpp | 11 +++++ .../rapidsmpf/streaming/core/fanout.hpp | 11 +---- cpp/src/buffer/resource.cpp | 9 +++- cpp/src/streaming/core/fanout.cpp | 22 +++++++--- cpp/tests/streaming/test_message.cpp | 42 +++++++++++++++++++ 7 files changed, 133 insertions(+), 17 deletions(-) diff --git a/cpp/include/rapidsmpf/buffer/buffer.hpp b/cpp/include/rapidsmpf/buffer/buffer.hpp index 103e1affd..13d9fed4f 100644 --- a/cpp/include/rapidsmpf/buffer/buffer.hpp +++ b/cpp/include/rapidsmpf/buffer/buffer.hpp @@ -24,7 +24,7 @@ namespace rapidsmpf { /// @brief Enum representing the type of memory. -enum class MemoryType : int { +enum class MemoryType : uint8_t { DEVICE = 0, ///< Device memory HOST = 1 ///< Host memory }; @@ -36,6 +36,20 @@ constexpr MemoryType LowestSpillType = MemoryType::HOST; /// @note Ensure that this array is always sorted in decreasing order of preference. constexpr std::array MEMORY_TYPES{{MemoryType::DEVICE, MemoryType::HOST}}; +/** + * @brief Get the lower memory types inclusive of the given memory type. + * + * @param mem_type The memory type to get the lower memory types for. + * @return A span of the lower memory types inclusive of the given memory type. 
+ */ +constexpr std::span LowerMemoryTypesInclusive( + MemoryType mem_type +) noexcept { + return std::span{ + MEMORY_TYPES.begin() + static_cast(mem_type), MEMORY_TYPES.end() + }; +} + /** * @brief Buffer representing device or host memory. * diff --git a/cpp/include/rapidsmpf/buffer/content_description.hpp b/cpp/include/rapidsmpf/buffer/content_description.hpp index 646a876aa..5a34569ef 100644 --- a/cpp/include/rapidsmpf/buffer/content_description.hpp +++ b/cpp/include/rapidsmpf/buffer/content_description.hpp @@ -3,7 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include #include +#include #include #include @@ -64,6 +66,10 @@ class ContentDescription { std::pair> constexpr explicit ContentDescription(Range&& sizes, Spillable spillable) : spillable_(spillable == Spillable::YES) { + RAPIDSMPF_EXPECTS( + std::ranges::size(sizes) > 0, + "ContentDescription must have at least one memory type" + ); content_sizes_.fill(0); for (auto&& [mem_type, size] : sizes) { auto idx = static_cast(mem_type); @@ -104,6 +110,39 @@ class ContentDescription { return content_sizes_[static_cast(mem_type)]; } + /** + * @brief Get the highest memory type that has a non-zero size. + * + * @return The highest memory type that has a non-zero size. If no memory type has a + * non-zero size, returns `MemoryType::HOST`. + */ + constexpr MemoryType highest_memory_type_set() const noexcept { + auto it = std::ranges::find_if(content_sizes_, [](auto const& size) { + return size > 0; + }); + return it == content_sizes_.end() + ? MemoryType::HOST + : static_cast(std::distance(content_sizes_.begin(), it)); + } + + /** + * @brief Get the lowest memory type that has a non-zero size. + * + * @return The lowest memory type that has a non-zero size. If no memory type has a + * non-zero size, returns `MemoryType::HOST`. + */ + constexpr MemoryType lowest_memory_type_set() const noexcept { + auto it = std::ranges::find_if( + std::ranges::reverse_view(content_sizes_), + [](auto const& size) { return size > 0; } + ); + return it == content_sizes_.rend() + ? MemoryType::HOST + : static_cast( + std::distance(content_sizes_.begin(), it.base()) - 1 + ); + } + /** * @brief Get the total content size across all memory types. * diff --git a/cpp/include/rapidsmpf/buffer/resource.hpp b/cpp/include/rapidsmpf/buffer/resource.hpp index ebf5f490a..ddd72f9a9 100644 --- a/cpp/include/rapidsmpf/buffer/resource.hpp +++ b/cpp/include/rapidsmpf/buffer/resource.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -265,6 +266,16 @@ class BufferResource { size_t size, std::optional mem_type = std::nullopt ); + /** + * @brief Make a memory reservation or fail. + * + * @param size The size of the buffer to allocate. + * @param mem_types The memory types to try to allocate the buffer from. + * @return A memory reservation. + * @throws std::runtime_error if no memory reservation was made. + */ + MemoryReservation reserve_or_fail(size_t size, std::span mem_types); + /** * @brief Consume a portion of the reserved memory. * diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp index 01e027556..59a64cc6d 100644 --- a/cpp/include/rapidsmpf/streaming/core/fanout.hpp +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -34,10 +34,6 @@ enum class FanoutPolicy : uint8_t { * This mode disables coordinated backpressure between outputs, allowing * consumers to process at independent rates, but can lead to unbounded * buffering and increased memory usage. 
- * - * @note Consumers might not receive any messages until *all* upstream - * messages have been sent, depending on the implementation and buffering - * strategy. */ UNBOUNDED, }; @@ -49,9 +45,7 @@ enum class FanoutPolicy : uint8_t { * them to all output channels according to the selected fanout policy, see * ::FanoutPolicy. * - * Each output channel receives a shallow copy of the same message; no payload - * data is duplicated. All copies share the same underlying payload, ensuring - * zero-copy broadcast semantics. + * Each output channel receives a deep copy of the same message. * * @param ctx The node context to use. * @param ch_in Input channel from which messages are received. @@ -61,9 +55,6 @@ enum class FanoutPolicy : uint8_t { * @return Streaming node representing the fanout operation. * * @throws std::invalid_argument If an unknown fanout policy is specified. - * - * @note Since messages are shallow-copied, releasing a payload (`release()`) - * is only valid on messages that hold exclusive ownership of the payload. */ Node fanout( std::shared_ptr ctx, diff --git a/cpp/src/buffer/resource.cpp b/cpp/src/buffer/resource.cpp index b35c859f1..b040f5ccb 100644 --- a/cpp/src/buffer/resource.cpp +++ b/cpp/src/buffer/resource.cpp @@ -105,7 +105,14 @@ MemoryReservation BufferResource::reserve_or_fail( } // try to allocate data buffer from memory types in order [DEVICE, HOST] - for (auto mem_type : MEMORY_TYPES) { + return reserve_or_fail(size, MEMORY_TYPES); +} + +MemoryReservation BufferResource::reserve_or_fail( + size_t size, std::span mem_types +) { + // try to allocate data buffer from memory types in order [DEVICE, HOST] + for (auto const& mem_type : mem_types) { auto [res, _] = reserve(mem_type, size, false); if (res.size() == size) { return std::move(res); diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 46ff2e125..be1445898 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -30,10 +31,13 @@ Node send_to_channels( ) { std::vector> tasks; tasks.reserve(chs_out.size()); + + auto try_memory_types = + LowerMemoryTypesInclusive(msg.content_description().lowest_memory_type_set()); for (auto& ch_out : chs_out) { // do a reservation for each copy, so that it will fallback to host memory if // needed - auto res = ctx->br()->reserve_or_fail(msg.copy_cost()); + auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types); tasks.push_back(ch_out->send(msg.copy(res))); } coro_results(co_await coro::when_all(std::move(tasks))); @@ -69,9 +73,12 @@ Node bounded_fanout( logger.debug("Sent message ", msg.sequence_number()); } - for (auto& ch : chs_out) { - co_await ch->drain(ctx->executor()); - } + std::vector> drain_tasks; + drain_tasks.reserve(chs_out.size()); + std::ranges::for_each(chs_out, [&](auto& ch) { + drain_tasks.emplace_back(ch->drain(ctx->executor())); + }); + coro_results(co_await coro::when_all(std::move(drain_tasks))); logger.debug("Completed bounded fanout"); } @@ -140,7 +147,12 @@ Node unbounded_fo_send_task( // make reservations for each message so that it will fallback to host memory // if needed - auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); + + auto try_memory_types = LowerMemoryTypesInclusive( + msg.content_description().lowest_memory_type_set() + ); + + auto res = ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types); RAPIDSMPF_EXPECTS( co_await ch_out->send(msg.copy(res)), "failed to send message" ); diff 
--git a/cpp/tests/streaming/test_message.cpp b/cpp/tests/streaming/test_message.cpp index c4bb09af1..13ea7f848 100644 --- a/cpp/tests/streaming/test_message.cpp +++ b/cpp/tests/streaming/test_message.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -137,3 +138,44 @@ TEST_F(StreamingMessage, CopyWithCallbacks) { EXPECT_EQ(m1.sequence_number(), m2.sequence_number()); } } + +namespace { +// test context description properties +constexpr ContentDescription cd{ + {{MemoryType::HOST, 10}, {MemoryType::DEVICE, 20}}, ContentDescription::Spillable::NO +}; +static_assert(cd.highest_memory_type_set() == MemoryType::DEVICE); +static_assert(cd.lowest_memory_type_set() == MemoryType::HOST); +static_assert(std::ranges::equal( + LowerMemoryTypesInclusive(cd.lowest_memory_type_set()), + std::span(MEMORY_TYPES).subspan(1) // [HOST] +)); + +constexpr ContentDescription cd2{ + {{MemoryType::DEVICE, 10}}, ContentDescription::Spillable::NO +}; +static_assert(cd2.highest_memory_type_set() == MemoryType::DEVICE); +static_assert(cd2.lowest_memory_type_set() == MemoryType::DEVICE); +static_assert(std::ranges::equal( + LowerMemoryTypesInclusive(cd2.lowest_memory_type_set()), + std::span(MEMORY_TYPES) +)); + +constexpr ContentDescription cd3{ + {{MemoryType::HOST, 10}}, ContentDescription::Spillable::NO +}; +static_assert(cd3.highest_memory_type_set() == MemoryType::HOST); +static_assert(cd3.lowest_memory_type_set() == MemoryType::HOST); +static_assert(std::ranges::equal( + LowerMemoryTypesInclusive(cd3.lowest_memory_type_set()), + std::span(MEMORY_TYPES).subspan(1) // [HOST] +)); + +constexpr ContentDescription cd4{}; +static_assert(cd4.highest_memory_type_set() == MemoryType::HOST); +static_assert(cd4.lowest_memory_type_set() == MemoryType::HOST); +static_assert(std::ranges::equal( + LowerMemoryTypesInclusive(cd4.lowest_memory_type_set()), + std::span(MEMORY_TYPES).subspan(1) // [HOST] +)); +} // namespace From d99d7c4f1bd089183aa6e13e26a56aeb58e6c141 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 10 Nov 2025 12:22:27 -0800 Subject: [PATCH 13/43] remove size checkl Signed-off-by: niranda perera --- cpp/include/rapidsmpf/buffer/content_description.hpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/include/rapidsmpf/buffer/content_description.hpp b/cpp/include/rapidsmpf/buffer/content_description.hpp index 5a34569ef..0cbab251f 100644 --- a/cpp/include/rapidsmpf/buffer/content_description.hpp +++ b/cpp/include/rapidsmpf/buffer/content_description.hpp @@ -66,10 +66,6 @@ class ContentDescription { std::pair> constexpr explicit ContentDescription(Range&& sizes, Spillable spillable) : spillable_(spillable == Spillable::YES) { - RAPIDSMPF_EXPECTS( - std::ranges::size(sizes) > 0, - "ContentDescription must have at least one memory type" - ); content_sizes_.fill(0); for (auto&& [mem_type, size] : sizes) { auto idx = static_cast(mem_type); From dccd9dd274abe1b18e0ed70f8b04a1bcaf3e412a Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 10 Nov 2025 14:20:25 -0800 Subject: [PATCH 14/43] Revert "remove size checkl" This reverts commit d99d7c4f1bd089183aa6e13e26a56aeb58e6c141. 
--- cpp/include/rapidsmpf/buffer/content_description.hpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/include/rapidsmpf/buffer/content_description.hpp b/cpp/include/rapidsmpf/buffer/content_description.hpp index 0cbab251f..5a34569ef 100644 --- a/cpp/include/rapidsmpf/buffer/content_description.hpp +++ b/cpp/include/rapidsmpf/buffer/content_description.hpp @@ -66,6 +66,10 @@ class ContentDescription { std::pair> constexpr explicit ContentDescription(Range&& sizes, Spillable spillable) : spillable_(spillable == Spillable::YES) { + RAPIDSMPF_EXPECTS( + std::ranges::size(sizes) > 0, + "ContentDescription must have at least one memory type" + ); content_sizes_.fill(0); for (auto&& [mem_type, size] : sizes) { auto idx = static_cast(mem_type); From 33eafbc6518c4a556f80d842dda5a2362e25c139 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 10 Nov 2025 14:20:59 -0800 Subject: [PATCH 15/43] Revert "adding lower mem types" This reverts commit ea945ca7cf11de12cdb1df31fcff05d69719bc92. --- cpp/include/rapidsmpf/buffer/buffer.hpp | 16 +------ .../rapidsmpf/buffer/content_description.hpp | 39 ----------------- cpp/include/rapidsmpf/buffer/resource.hpp | 11 ----- .../rapidsmpf/streaming/core/fanout.hpp | 11 ++++- cpp/src/buffer/resource.cpp | 9 +--- cpp/src/streaming/core/fanout.cpp | 22 +++------- cpp/tests/streaming/test_message.cpp | 42 ------------------- 7 files changed, 17 insertions(+), 133 deletions(-) diff --git a/cpp/include/rapidsmpf/buffer/buffer.hpp b/cpp/include/rapidsmpf/buffer/buffer.hpp index 13d9fed4f..103e1affd 100644 --- a/cpp/include/rapidsmpf/buffer/buffer.hpp +++ b/cpp/include/rapidsmpf/buffer/buffer.hpp @@ -24,7 +24,7 @@ namespace rapidsmpf { /// @brief Enum representing the type of memory. -enum class MemoryType : uint8_t { +enum class MemoryType : int { DEVICE = 0, ///< Device memory HOST = 1 ///< Host memory }; @@ -36,20 +36,6 @@ constexpr MemoryType LowestSpillType = MemoryType::HOST; /// @note Ensure that this array is always sorted in decreasing order of preference. constexpr std::array MEMORY_TYPES{{MemoryType::DEVICE, MemoryType::HOST}}; -/** - * @brief Get the lower memory types inclusive of the given memory type. - * - * @param mem_type The memory type to get the lower memory types for. - * @return A span of the lower memory types inclusive of the given memory type. - */ -constexpr std::span LowerMemoryTypesInclusive( - MemoryType mem_type -) noexcept { - return std::span{ - MEMORY_TYPES.begin() + static_cast(mem_type), MEMORY_TYPES.end() - }; -} - /** * @brief Buffer representing device or host memory. * diff --git a/cpp/include/rapidsmpf/buffer/content_description.hpp b/cpp/include/rapidsmpf/buffer/content_description.hpp index 5a34569ef..646a876aa 100644 --- a/cpp/include/rapidsmpf/buffer/content_description.hpp +++ b/cpp/include/rapidsmpf/buffer/content_description.hpp @@ -3,9 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include #include -#include #include #include @@ -66,10 +64,6 @@ class ContentDescription { std::pair> constexpr explicit ContentDescription(Range&& sizes, Spillable spillable) : spillable_(spillable == Spillable::YES) { - RAPIDSMPF_EXPECTS( - std::ranges::size(sizes) > 0, - "ContentDescription must have at least one memory type" - ); content_sizes_.fill(0); for (auto&& [mem_type, size] : sizes) { auto idx = static_cast(mem_type); @@ -110,39 +104,6 @@ class ContentDescription { return content_sizes_[static_cast(mem_type)]; } - /** - * @brief Get the highest memory type that has a non-zero size. 
- * - * @return The highest memory type that has a non-zero size. If no memory type has a - * non-zero size, returns `MemoryType::HOST`. - */ - constexpr MemoryType highest_memory_type_set() const noexcept { - auto it = std::ranges::find_if(content_sizes_, [](auto const& size) { - return size > 0; - }); - return it == content_sizes_.end() - ? MemoryType::HOST - : static_cast(std::distance(content_sizes_.begin(), it)); - } - - /** - * @brief Get the lowest memory type that has a non-zero size. - * - * @return The lowest memory type that has a non-zero size. If no memory type has a - * non-zero size, returns `MemoryType::HOST`. - */ - constexpr MemoryType lowest_memory_type_set() const noexcept { - auto it = std::ranges::find_if( - std::ranges::reverse_view(content_sizes_), - [](auto const& size) { return size > 0; } - ); - return it == content_sizes_.rend() - ? MemoryType::HOST - : static_cast( - std::distance(content_sizes_.begin(), it.base()) - 1 - ); - } - /** * @brief Get the total content size across all memory types. * diff --git a/cpp/include/rapidsmpf/buffer/resource.hpp b/cpp/include/rapidsmpf/buffer/resource.hpp index ddd72f9a9..ebf5f490a 100644 --- a/cpp/include/rapidsmpf/buffer/resource.hpp +++ b/cpp/include/rapidsmpf/buffer/resource.hpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -266,16 +265,6 @@ class BufferResource { size_t size, std::optional mem_type = std::nullopt ); - /** - * @brief Make a memory reservation or fail. - * - * @param size The size of the buffer to allocate. - * @param mem_types The memory types to try to allocate the buffer from. - * @return A memory reservation. - * @throws std::runtime_error if no memory reservation was made. - */ - MemoryReservation reserve_or_fail(size_t size, std::span mem_types); - /** * @brief Consume a portion of the reserved memory. * diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp index 59a64cc6d..01e027556 100644 --- a/cpp/include/rapidsmpf/streaming/core/fanout.hpp +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -34,6 +34,10 @@ enum class FanoutPolicy : uint8_t { * This mode disables coordinated backpressure between outputs, allowing * consumers to process at independent rates, but can lead to unbounded * buffering and increased memory usage. + * + * @note Consumers might not receive any messages until *all* upstream + * messages have been sent, depending on the implementation and buffering + * strategy. */ UNBOUNDED, }; @@ -45,7 +49,9 @@ enum class FanoutPolicy : uint8_t { * them to all output channels according to the selected fanout policy, see * ::FanoutPolicy. * - * Each output channel receives a deep copy of the same message. + * Each output channel receives a shallow copy of the same message; no payload + * data is duplicated. All copies share the same underlying payload, ensuring + * zero-copy broadcast semantics. * * @param ctx The node context to use. * @param ch_in Input channel from which messages are received. @@ -55,6 +61,9 @@ enum class FanoutPolicy : uint8_t { * @return Streaming node representing the fanout operation. * * @throws std::invalid_argument If an unknown fanout policy is specified. + * + * @note Since messages are shallow-copied, releasing a payload (`release()`) + * is only valid on messages that hold exclusive ownership of the payload. 
*/ Node fanout( std::shared_ptr ctx, diff --git a/cpp/src/buffer/resource.cpp b/cpp/src/buffer/resource.cpp index b040f5ccb..b35c859f1 100644 --- a/cpp/src/buffer/resource.cpp +++ b/cpp/src/buffer/resource.cpp @@ -105,14 +105,7 @@ MemoryReservation BufferResource::reserve_or_fail( } // try to allocate data buffer from memory types in order [DEVICE, HOST] - return reserve_or_fail(size, MEMORY_TYPES); -} - -MemoryReservation BufferResource::reserve_or_fail( - size_t size, std::span mem_types -) { - // try to allocate data buffer from memory types in order [DEVICE, HOST] - for (auto const& mem_type : mem_types) { + for (auto mem_type : MEMORY_TYPES) { auto [res, _] = reserve(mem_type, size, false); if (res.size() == size) { return std::move(res); diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 062ae5ba1..dae60c673 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -5,7 +5,6 @@ #include #include -#include #include #include @@ -31,13 +30,10 @@ Node send_to_channels( ) { std::vector> tasks; tasks.reserve(chs_out.size()); - - auto try_memory_types = - LowerMemoryTypesInclusive(msg.content_description().lowest_memory_type_set()); for (auto& ch_out : chs_out) { // do a reservation for each copy, so that it will fallback to host memory if // needed - auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types); + auto res = ctx->br()->reserve_or_fail(msg.copy_cost()); tasks.push_back(ch_out->send(msg.copy(res))); } coro_results(co_await coro::when_all(std::move(tasks))); @@ -73,12 +69,9 @@ Node bounded_fanout( logger.debug("Sent message ", msg.sequence_number()); } - std::vector> drain_tasks; - drain_tasks.reserve(chs_out.size()); - std::ranges::for_each(chs_out, [&](auto& ch) { - drain_tasks.emplace_back(ch->drain(ctx->executor())); - }); - coro_results(co_await coro::when_all(std::move(drain_tasks))); + for (auto& ch : chs_out) { + co_await ch->drain(ctx->executor()); + } logger.debug("Completed bounded fanout"); } @@ -147,12 +140,7 @@ Node unbounded_fo_send_task( // make reservations for each message so that it will fallback to host memory // if needed - - auto try_memory_types = LowerMemoryTypesInclusive( - msg.content_description().lowest_memory_type_set() - ); - - auto res = ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types); + auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); RAPIDSMPF_EXPECTS( co_await ch_out->send(msg.copy(res)), "failed to send message" ); diff --git a/cpp/tests/streaming/test_message.cpp b/cpp/tests/streaming/test_message.cpp index 13ea7f848..c4bb09af1 100644 --- a/cpp/tests/streaming/test_message.cpp +++ b/cpp/tests/streaming/test_message.cpp @@ -9,7 +9,6 @@ #include #include -#include #include #include @@ -138,44 +137,3 @@ TEST_F(StreamingMessage, CopyWithCallbacks) { EXPECT_EQ(m1.sequence_number(), m2.sequence_number()); } } - -namespace { -// test context description properties -constexpr ContentDescription cd{ - {{MemoryType::HOST, 10}, {MemoryType::DEVICE, 20}}, ContentDescription::Spillable::NO -}; -static_assert(cd.highest_memory_type_set() == MemoryType::DEVICE); -static_assert(cd.lowest_memory_type_set() == MemoryType::HOST); -static_assert(std::ranges::equal( - LowerMemoryTypesInclusive(cd.lowest_memory_type_set()), - std::span(MEMORY_TYPES).subspan(1) // [HOST] -)); - -constexpr ContentDescription cd2{ - {{MemoryType::DEVICE, 10}}, ContentDescription::Spillable::NO -}; -static_assert(cd2.highest_memory_type_set() == MemoryType::DEVICE); 
-static_assert(cd2.lowest_memory_type_set() == MemoryType::DEVICE); -static_assert(std::ranges::equal( - LowerMemoryTypesInclusive(cd2.lowest_memory_type_set()), - std::span(MEMORY_TYPES) -)); - -constexpr ContentDescription cd3{ - {{MemoryType::HOST, 10}}, ContentDescription::Spillable::NO -}; -static_assert(cd3.highest_memory_type_set() == MemoryType::HOST); -static_assert(cd3.lowest_memory_type_set() == MemoryType::HOST); -static_assert(std::ranges::equal( - LowerMemoryTypesInclusive(cd3.lowest_memory_type_set()), - std::span(MEMORY_TYPES).subspan(1) // [HOST] -)); - -constexpr ContentDescription cd4{}; -static_assert(cd4.highest_memory_type_set() == MemoryType::HOST); -static_assert(cd4.lowest_memory_type_set() == MemoryType::HOST); -static_assert(std::ranges::equal( - LowerMemoryTypesInclusive(cd4.lowest_memory_type_set()), - std::span(MEMORY_TYPES).subspan(1) // [HOST] -)); -} // namespace From a3752ca5ce6609eb52a37609d7224a340abd5e10 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 10 Nov 2025 15:00:39 -0800 Subject: [PATCH 16/43] addressing comments Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/fanout.hpp | 11 +------- cpp/src/streaming/core/fanout.cpp | 27 ++++++++++++++++--- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp index 01e027556..59a64cc6d 100644 --- a/cpp/include/rapidsmpf/streaming/core/fanout.hpp +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -34,10 +34,6 @@ enum class FanoutPolicy : uint8_t { * This mode disables coordinated backpressure between outputs, allowing * consumers to process at independent rates, but can lead to unbounded * buffering and increased memory usage. - * - * @note Consumers might not receive any messages until *all* upstream - * messages have been sent, depending on the implementation and buffering - * strategy. */ UNBOUNDED, }; @@ -49,9 +45,7 @@ enum class FanoutPolicy : uint8_t { * them to all output channels according to the selected fanout policy, see * ::FanoutPolicy. * - * Each output channel receives a shallow copy of the same message; no payload - * data is duplicated. All copies share the same underlying payload, ensuring - * zero-copy broadcast semantics. + * Each output channel receives a deep copy of the same message. * * @param ctx The node context to use. * @param ch_in Input channel from which messages are received. @@ -61,9 +55,6 @@ enum class FanoutPolicy : uint8_t { * @return Streaming node representing the fanout operation. * * @throws std::invalid_argument If an unknown fanout policy is specified. - * - * @note Since messages are shallow-copied, releasing a payload (`release()`) - * is only valid on messages that hold exclusive ownership of the payload. */ Node fanout( std::shared_ptr ctx, diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index dae60c673..0dfa66525 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include @@ -18,6 +18,24 @@ namespace rapidsmpf::streaming::node { namespace { +/** + * @brief Try to allocate memory from the memory types that the message content uses. + * + * @param msg The message to allocate memory for. + * @return The memory types to try to allocate from. 
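+ *
+ * For example, a message holding any device content yields the full
+ * {DEVICE, HOST} preference order, whereas a host-only message yields just
+ * {HOST}; a reservation made from this span may therefore fall back to host
+ * memory but never places host-only content on the device.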
+ */ +constexpr std::span try_memory_types(Message const& msg) { + auto const& cd = msg.content_description(); + // if the message content uses device memory, try to allocate from device memory + // first, else allocate from host memory + return cd.content_size(MemoryType::DEVICE) > 0 + ? MEMORY_TYPES + : std::span{ + MEMORY_TYPES.begin() + static_cast(MemoryType::HOST), + MEMORY_TYPES.end() + }; +} + /** * @brief Asynchronously send a message to multiple output channels. * @@ -33,7 +51,8 @@ Node send_to_channels( for (auto& ch_out : chs_out) { // do a reservation for each copy, so that it will fallback to host memory if // needed - auto res = ctx->br()->reserve_or_fail(msg.copy_cost()); + // TODO: change this + auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)[0]); tasks.push_back(ch_out->send(msg.copy(res))); } coro_results(co_await coro::when_all(std::move(tasks))); @@ -140,7 +159,9 @@ Node unbounded_fo_send_task( // make reservations for each message so that it will fallback to host memory // if needed - auto res = ctx.br()->reserve_or_fail(msg.copy_cost()); + // TODO: change this + auto res = + ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)[0]); RAPIDSMPF_EXPECTS( co_await ch_out->send(msg.copy(res)), "failed to send message" ); From a7453d6e3f9d3582297b87f825299d53d6712029 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 12 Nov 2025 14:17:11 -0800 Subject: [PATCH 17/43] addressing cpp comments Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/channel.hpp | 9 +- cpp/src/streaming/core/fanout.cpp | 170 ++++++++++-------- 2 files changed, 103 insertions(+), 76 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 64473dbd1..6163bbf6b 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -44,6 +44,8 @@ class Channel { * @param msg The msg to send. * @return A coroutine that evaluates to true if the msg was successfully sent or * false if the channel was shut down. + * + * @throws std::logic_error If the message is empty. */ coro::task send(Message msg); @@ -54,6 +56,8 @@ class Channel { * * @return A coroutine that evaluates to the message, which will be empty if the * channel is shut down. + * + * @throws std::logic_error If the received message is empty. */ coro::task receive(); @@ -294,9 +298,8 @@ class ShutdownAtExit { template explicit ShutdownAtExit(T&&... channels) requires(std::convertible_to> && ...) - : ShutdownAtExit( - std::vector>{std::forward(channels)...} - ) {} + : ShutdownAtExit(std::vector>{std::forward(channels + )...}) {} // Non-copyable, non-movable. ShutdownAtExit(ShutdownAtExit const&) = delete; diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 0dfa66525..de9fa297e 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -44,25 +44,27 @@ constexpr std::span try_memory_types(Message const& msg) { * @param chs_out The set of output channels to which the message is sent. 
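+ * @note The last output channel receives the original message by move, so for
+ * `N` output channels only `N - 1` copies (and reservations) are made.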
*/ Node send_to_channels( - Context* ctx, Message const& msg, std::vector>& chs_out + Context* ctx, Message&& msg, std::vector>& chs_out ) { std::vector> tasks; tasks.reserve(chs_out.size()); - for (auto& ch_out : chs_out) { + for (size_t i = 0; i < chs_out.size() - 1; i++) { // do a reservation for each copy, so that it will fallback to host memory if // needed // TODO: change this auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)[0]); - tasks.push_back(ch_out->send(msg.copy(res))); + tasks.emplace_back(chs_out[i]->send(msg.copy(res))); } + // move the message to the last channel to avoid extra copy + tasks.emplace_back(chs_out.back()->send(std::move(msg))); coro_results(co_await coro::when_all(std::move(tasks))); } /** * @brief Broadcast messages from one input channel to multiple output channels. * - * @note Bounded fanout requires all the output channels to consume messages before the - * next message is sent/consumed from the input channel. + * @note Bounded fanout requires all the output channels to consume messages before + * the next message is sent/consumed from the input channel. * * @param ctx The context to use. * @param ch_in The input channel to receive messages from. @@ -84,13 +86,16 @@ Node bounded_fanout( break; } - co_await send_to_channels(ctx.get(), msg, chs_out); + co_await send_to_channels(ctx.get(), std::move(msg), chs_out); logger.debug("Sent message ", msg.sequence_number()); } + std::vector drain_tasks; + drain_tasks.reserve(chs_out.size()); for (auto& ch : chs_out) { - co_await ch->drain(ctx->executor()); + drain_tasks.emplace_back(ch->drain(ctx->executor())); } + coro_results(co_await coro::when_all(std::move(drain_tasks))); logger.debug("Completed bounded fanout"); } @@ -128,8 +133,8 @@ struct UnboundedFanoutState { Node unbounded_fo_send_task( Context& ctx, size_t idx, - std::shared_ptr& ch_out, - UnboundedFanoutState& state + std::shared_ptr ch_out, + std::shared_ptr state ) { ShutdownAtExit ch_shutdown{ch_out}; co_await ctx.executor()->schedule(); @@ -139,22 +144,22 @@ Node unbounded_fo_send_task( size_t curr_recv_msg_sz = 0; // current size of the recv_messages deque while (true) { { - auto lock = co_await state.mtx.scoped_lock(); - co_await state.data_ready.wait(lock, [&] { + auto lock = co_await state->mtx.scoped_lock(); + co_await state->data_ready.wait(lock, [&] { // irrespective of input_done, update the end_idx to the total number of // messages - curr_recv_msg_sz = state.recv_messages.size(); - return state.input_done || state.ch_next_idx[idx] < curr_recv_msg_sz; + curr_recv_msg_sz = state->recv_messages.size(); + return state->input_done || state->ch_next_idx[idx] < curr_recv_msg_sz; }); - if (state.input_done && state.ch_next_idx[idx] == curr_recv_msg_sz) { + if (state->input_done && state->ch_next_idx[idx] == curr_recv_msg_sz) { // no more messages will be received, and all messages have been sent break; } } // now we can copy & send messages in indices [next_idx, end_idx) - for (size_t i = state.ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { - auto const& msg = state.recv_messages[i]; + for (size_t i = state->ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { + auto const& msg = state->recv_messages[i]; RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); // make reservations for each message so that it will fallback to host memory @@ -167,74 +172,51 @@ Node unbounded_fo_send_task( ); } logger.trace( - "sent ", idx, " [", state.ch_next_idx[idx], ", ", curr_recv_msg_sz, ")" + "sent ", idx, " [", 
state->ch_next_idx[idx], ", ", curr_recv_msg_sz, ")" ); // now next_idx can be updated to end_idx, and if !input_done, we need to request // parent task for more data - auto lock = co_await state.mtx.scoped_lock(); - state.ch_next_idx[idx] = curr_recv_msg_sz; - if (state.ch_next_idx[idx] == state.recv_messages.size()) { - if (state.input_done) { + auto lock = co_await state->mtx.scoped_lock(); + state->ch_next_idx[idx] = curr_recv_msg_sz; + if (state->ch_next_idx[idx] == state->recv_messages.size()) { + if (state->input_done) { break; // no more messages will be received, and all messages have been // sent } else { // request more data from the input channel lock.unlock(); - co_await state.request_data.notify_one(); + co_await state->request_data.notify_one(); } } } co_await ch_out->drain(ctx.executor()); - logger.debug("Send task ", idx, " completed"); + logger.trace("Send task ", idx, " completed"); } -/** - * @brief Broadcast messages from one input channel to multiple output channels. - * - * This is a general purpose implementation which can support consuming messages by any - * channel. A consumer node can decide to consume all messages from a single channel - * before moving to the next channel, or it can consume messages from all channels before - * moving to the next message. When a message has been sent to all output channels, it is - * purged from the internal deque. - * - * @param ctx The context to use. - * @param ch_in The input channel to receive messages from. - * @param chs_out The output channels to send messages to. - * @return A node representing the unbounded fanout operation. - */ -Node unbounded_fanout( - std::shared_ptr ctx, +Node unbounded_fo_process_input_task( + Context& ctx, std::shared_ptr ch_in, - std::vector> chs_out + std::shared_ptr state ) { ShutdownAtExit ch_in_shutdown{ch_in}; - ShutdownAtExit chs_out_shutdown{chs_out}; - co_await ctx->executor()->schedule(); - - auto& logger = ctx->logger(); - - logger.debug("Scheduled unbounded fanout"); - UnboundedFanoutState state(chs_out.size()); + co_await ctx.executor()->schedule(); + auto& logger = ctx.logger(); - // start send tasks for each output channel - coro::task_container tasks(ctx->executor()); - for (size_t i = 0; i < chs_out.size(); i++) { - RAPIDSMPF_EXPECTS( - tasks.start(unbounded_fo_send_task(*ctx, i, chs_out[i], state)), - "failed to start send task" - ); - } + logger.trace("Scheduled process input task"); // input_done is only set by this task, so reading without lock is safe here - while (!state.input_done) { + while (!state->input_done) { + size_t last_completed_idx, latest_processed_idx; { - auto lock = co_await state.mtx.scoped_lock(); - co_await state.request_data.wait(lock, [&] { - return std::ranges::any_of(state.ch_next_idx, [&](size_t next_idx) { - return state.recv_messages.size() == next_idx; - }); + auto lock = co_await state->mtx.scoped_lock(); + co_await state->request_data.wait(lock, [&] { + auto res = std::ranges::minmax(state->ch_next_idx); + last_completed_idx = res.min; + latest_processed_idx = res.max; + + return latest_processed_idx == state->recv_messages.size(); }); } @@ -242,17 +224,17 @@ Node unbounded_fanout( auto msg = co_await ch_in->receive(); { // relock mtx to update input_done/ recv_messages - auto lock = co_await state.mtx.scoped_lock(); + auto lock = co_await state->mtx.scoped_lock(); if (msg.empty()) { - state.input_done = true; + state->input_done = true; } else { logger.trace("Received input", msg.sequence_number()); - 
state.recv_messages.emplace_back(std::move(msg)); + state->recv_messages.emplace_back(std::move(msg)); } } // notify send_tasks to copy & send messages - co_await state.data_ready.notify_all(); + co_await state->data_ready.notify_all(); // purge completed send_tasks. This will reset the messages to empty, so that they // release the memory, however the deque is not resized. This guarantees that the @@ -260,19 +242,61 @@ Node unbounded_fanout( // intentionally not locking the mtx here, because we only need to know a // lower-bound on the last completed idx (ch_next_idx values are monotonically // increasing) - size_t last_completed_idx = std::ranges::min(state.ch_next_idx); - while (state.purge_idx + 1 < last_completed_idx) { - state.recv_messages[state.purge_idx].reset(); - state.purge_idx++; + while (state->purge_idx + 1 < last_completed_idx) { + state->recv_messages[state->purge_idx].reset(); + state->purge_idx++; } logger.trace( - "recv_messages active size: ", state.recv_messages.size() - state.purge_idx + "recv_messages active size: ", state->recv_messages.size() - state->purge_idx ); } - // Note: there will be some messages to be purged after the loop exits, but we don't - // need to do anything about them here - co_await tasks.yield_until_empty(); + co_await ch_in->drain(ctx.executor()); + logger.trace("Process input task completed"); +} + +/** + * @brief Broadcast messages from one input channel to multiple output channels. + * + * This is an all-purpose implementation that can support consuming messages by the + * channel order or message order. Output channels could be connected to single/multiple + * consumer nodes. A consumer node can decide to consume all messages from a single channel + * before moving to the next channel, or it can consume messages from all channels before + * moving to the next message. When a message has been sent to all output channels, it is + * purged from the internal deque. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param chs_out The output channels to send messages to. + * @return A node representing the unbounded fanout operation. 
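+ *
+ * Roughly, the bookkeeping in `UnboundedFanoutState` maintains
+ * `purge_idx <= min(ch_next_idx) <= max(ch_next_idx) <= recv_messages.size()`:
+ * entries below `purge_idx` have already been reset, and send task `i` still
+ * has to forward the messages in `[ch_next_idx[i], recv_messages.size())`.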
+ */ +Node unbounded_fanout( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector> chs_out +) { + ShutdownAtExit ch_in_shutdown{ch_in}; + ShutdownAtExit chs_out_shutdown{chs_out}; + co_await ctx->executor()->schedule(); + auto& logger = ctx->logger(); + auto state = std::make_shared(chs_out.size()); + + std::vector tasks; + tasks.reserve(chs_out.size() + 1); + + auto& executor = *ctx->executor(); + // schedule send tasks for each output channel + for (size_t i = 0; i < chs_out.size(); i++) { + tasks.emplace_back(executor.schedule( + unbounded_fo_send_task(*ctx, i, std::move(chs_out[i]), state) + )); + } + // schedule process input task + tasks.emplace_back(executor.schedule( + unbounded_fo_process_input_task(*ctx, std::move(ch_in), std::move(state)) + )); + + coro_results(co_await coro::when_all(std::move(tasks))); logger.debug("Unbounded fanout completed"); } From 969bf5e6de7a627bff41e49abcd1152d510af87e Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 12 Nov 2025 14:45:51 -0800 Subject: [PATCH 18/43] addressing python comments Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/channel.hpp | 5 ++- cpp/src/streaming/core/fanout.cpp | 12 +++-- .../rapidsmpf/streaming/core/fanout.pxd | 9 ++-- .../rapidsmpf/streaming/core/fanout.pyx | 45 +++---------------- .../rapidsmpf/tests/streaming/test_fanout.py | 28 +++--------- 5 files changed, 27 insertions(+), 72 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/channel.hpp b/cpp/include/rapidsmpf/streaming/core/channel.hpp index 6163bbf6b..901c7205f 100644 --- a/cpp/include/rapidsmpf/streaming/core/channel.hpp +++ b/cpp/include/rapidsmpf/streaming/core/channel.hpp @@ -298,8 +298,9 @@ class ShutdownAtExit { template explicit ShutdownAtExit(T&&... channels) requires(std::convertible_to> && ...) - : ShutdownAtExit(std::vector>{std::forward(channels - )...}) {} + : ShutdownAtExit( + std::vector>{std::forward(channels)...} + ) {} // Non-copyable, non-movable. ShutdownAtExit(ShutdownAtExit const&) = delete; diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index de9fa297e..323e63ccb 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -46,6 +46,8 @@ constexpr std::span try_memory_types(Message const& msg) { Node send_to_channels( Context* ctx, Message&& msg, std::vector>& chs_out ) { + RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); + std::vector> tasks; tasks.reserve(chs_out.size()); for (size_t i = 0; i < chs_out.size() - 1; i++) { @@ -260,10 +262,10 @@ Node unbounded_fo_process_input_task( * * This is an all-purpose implementation that can support consuming messages by the * channel order or message order. Output channels could be connected to single/multiple - * consumer nodes. A consumer node can decide to consume all messages from a single channel - * before moving to the next channel, or it can consume messages from all channels before - * moving to the next message. When a message has been sent to all output channels, it is - * purged from the internal deque. + * consumer nodes. A consumer node can decide to consume all messages from a single + * channel before moving to the next channel, or it can consume messages from all channels + * before moving to the next message. When a message has been sent to all output channels, + * it is purged from the internal deque. * * @param ctx The context to use. * @param ch_in The input channel to receive messages from. 
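The two consumption orders mentioned in the comment above can be sketched as
follows (illustrative only; `chs_out`, `out` and `num_msgs` are placeholders,
and the test suite exercises both orders via its `ConsumePolicy` sink):

    // message order: advance every channel one message at a time
    for (size_t m = 0; m < num_msgs; ++m) {
        for (auto& ch : chs_out) {
            out.push_back(co_await ch->receive());
        }
    }

    // channel order: drain one channel completely before moving to the next
    for (auto& ch : chs_out) {
        while (true) {
            auto msg = co_await ch->receive();
            if (msg.empty()) {
                break;
            }
            out.push_back(std::move(msg));
        }
    }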
@@ -308,6 +310,8 @@ Node fanout( std::vector> chs_out, FanoutPolicy policy ) { + RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); + switch (policy) { case FanoutPolicy::BOUNDED: return bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd index 50ea90b32..813d17617 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pxd @@ -12,15 +12,14 @@ from rapidsmpf.streaming.core.node cimport cpp_Node cdef extern from "" \ namespace "rapidsmpf::streaming::node" nogil: - cdef enum class cpp_FanoutPolicy \ - "rapidsmpf::streaming::node::FanoutPolicy" (uint8_t): - BOUNDED "rapidsmpf::streaming::node::FanoutPolicy::BOUNDED" - UNBOUNDED "rapidsmpf::streaming::node::FanoutPolicy::UNBOUNDED" + cpdef enum class FanoutPolicy (uint8_t): + BOUNDED + UNBOUNDED cdef cpp_Node cpp_fanout \ "rapidsmpf::streaming::node::fanout"( shared_ptr[cpp_Context] ctx, shared_ptr[cpp_Channel] ch_in, vector[shared_ptr[cpp_Channel]] chs_out, - cpp_FanoutPolicy policy + FanoutPolicy policy ) except + diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx index 889318a6b..159022f8f 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -1,47 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 -from enum import IntEnum - -from libcpp.memory cimport make_unique, shared_ptr +from libcpp.memory cimport make_unique from libcpp.utility cimport move from libcpp.vector cimport vector from rapidsmpf.streaming.core.channel cimport Channel from rapidsmpf.streaming.core.context cimport Context +from rapidsmpf.streaming.core.fanout cimport FanoutPolicy from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node -class FanoutPolicy(IntEnum): - """ - Fanout policy controlling how messages are propagated. - - Attributes - ---------- - BOUNDED : int - Process messages as they arrive and immediately forward them. - Messages are forwarded as soon as they are received from the input channel. - The next message is not processed until all output channels have completed - sending the current one, ensuring backpressure and synchronized flow. - UNBOUNDED : int - Forward messages without enforcing backpressure. - In this mode, messages may be accumulated internally before being - broadcast, or they may be forwarded immediately depending on the - implementation and downstream consumption rate. - - This mode disables coordinated backpressure between outputs, allowing - consumers to process at independent rates, but can lead to unbounded - buffering and increased memory usage. - - Note: Consumers might not receive any messages until *all* upstream - messages have been sent, depending on the implementation and buffering - strategy. - """ - BOUNDED = cpp_FanoutPolicy.BOUNDED - UNBOUNDED = cpp_FanoutPolicy.UNBOUNDED - - -def fanout(Context ctx, Channel ch_in, list chs_out, policy): +def fanout(Context ctx, Channel ch_in, list chs_out, FanoutPolicy policy): """ Broadcast messages from one input channel to multiple output channels. @@ -89,11 +59,9 @@ def fanout(Context ctx, Channel ch_in, list chs_out, policy): ... ctx, ch_in, [ch_out1, ch_out2], streaming.FanoutPolicy.BOUNDED ... 
) """ - # Validate policy - if not isinstance(policy, (FanoutPolicy, int)): - raise TypeError(f"policy must be a FanoutPolicy enum value, got {type(policy)}") - cdef vector[shared_ptr[cpp_Channel]] _chs_out + if len(chs_out) == 0: + raise ValueError("output channels cannot be empty") owner = [] for ch_out in chs_out: if not isinstance(ch_out, Channel): @@ -101,10 +69,9 @@ def fanout(Context ctx, Channel ch_in, list chs_out, policy): owner.append(ch_out) _chs_out.push_back((ch_out)._handle) - cdef cpp_FanoutPolicy _policy = (policy) cdef cpp_Node _ret with nogil: _ret = cpp_fanout( - ctx._handle, ch_in._handle, move(_chs_out), _policy + ctx._handle, ch_in._handle, move(_chs_out), policy ) return CppNode.from_handle(make_unique[cpp_Node](move(_ret)), owner) diff --git a/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py index a397011e0..8c38ff287 100644 --- a/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py +++ b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py @@ -86,8 +86,9 @@ def test_fanout_basic(context: Context, stream: Stream, policy: FanoutPolicy) -> @pytest.mark.parametrize("num_outputs", [1, 3, 5]) +@pytest.mark.parametrize("policy", [FanoutPolicy.BOUNDED, FanoutPolicy.UNBOUNDED]) def test_fanout_multiple_outputs( - context: Context, stream: Stream, num_outputs: int + context: Context, stream: Stream, num_outputs: int, policy: FanoutPolicy ) -> None: """Test fanout with varying numbers of output channels.""" # Create channels @@ -107,7 +108,7 @@ def test_fanout_multiple_outputs( # Create nodes push_node = push_to_channel(context, ch_in, messages) - fanout_node = fanout(context, ch_in, chs_out, FanoutPolicy.BOUNDED) + fanout_node = fanout(context, ch_in, chs_out, policy) pull_nodes = [] outputs = [] for ch_out in chs_out: @@ -133,27 +134,10 @@ def test_fanout_multiple_outputs( def test_fanout_empty_outputs(context: Context, stream: Stream) -> None: - """Test fanout with empty output list raises appropriate error or handles gracefully.""" + """Test fanout with empty output list raises value error.""" ch_in: Channel[TableChunk] = context.create_channel() - - # This should work but produce no output (or could raise ValueError) - # The C++ implementation may handle this differently - fanout_node = fanout(context, ch_in, [], FanoutPolicy.BOUNDED) - - # Create a simple message - df = cudf.DataFrame({"a": [1, 2, 3]}) - chunk = TableChunk.from_pylibcudf_table( - cudf_to_pylibcudf_table(df), stream, exclusive_view=False - ) - messages = [Message(0, chunk)] - push_node = push_to_channel(context, ch_in, messages) - - # Run pipeline - should complete without error - with ThreadPoolExecutor(max_workers=1) as executor: - run_streaming_pipeline( - nodes=[push_node, fanout_node], - py_executor=executor, - ) + with pytest.raises(ValueError): + fanout(context, ch_in, [], FanoutPolicy.BOUNDED) def test_fanout_policy_enum() -> None: From 06ec034e12298523d63d2ccbc073d32474b235e9 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 12 Nov 2025 14:51:13 -0800 Subject: [PATCH 19/43] minor Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 323e63ccb..4fb9999a3 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -159,7 +159,9 @@ Node unbounded_fo_send_task( } } - // now we can copy & send messages in indices [next_idx, 
end_idx) + // now we can copy & send messages in indices [next_idx, curr_recv_msg_sz) + // it is guaranteed that message purging will be done only on indices less than + // next_idx, so we can safely send messages without locking the mtx for (size_t i = state->ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { auto const& msg = state->recv_messages[i]; RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); From 4936a2bdfb4587a03ad98f3c5af67ce494886a8b Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 10:05:49 -0800 Subject: [PATCH 20/43] allowing output chs to shutdown prematurely Signed-off-by: niranda perera --- cpp/src/streaming/core/channel.cpp | 7 +- cpp/src/streaming/core/fanout.cpp | 106 +++++++++++++++++++++-------- 2 files changed, 82 insertions(+), 31 deletions(-) diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index 55486395a..0420bedbe 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -15,10 +15,9 @@ coro::task Channel::send(Message msg) { } coro::task Channel::receive() { - auto msg = co_await rb_.consume(); - if (msg.has_value()) { - RAPIDSMPF_EXPECTS(!msg->empty(), "received empty message"); - co_return sm_->extract(*msg); + auto msg_id = co_await rb_.consume(); + if (msg_id.has_value()) { + co_return sm_->extract(msg_id.value()); } else { co_return Message{}; } diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 4fb9999a3..2ef4f6ac3 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -53,8 +53,7 @@ Node send_to_channels( for (size_t i = 0; i < chs_out.size() - 1; i++) { // do a reservation for each copy, so that it will fallback to host memory if // needed - // TODO: change this - auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)[0]); + auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); tasks.emplace_back(chs_out[i]->send(msg.copy(res))); } // move the message to the last channel to avoid extra copy @@ -123,6 +122,12 @@ struct UnboundedFanoutState { size_t purge_idx{0}; }; +/** + * @brief Sentinel value indicating that the index is invalid. This is set when a failure + * occurs during send tasks. process input task will filter out messages with this index. + */ +constexpr size_t InvalidIdx = std::numeric_limits::max(); + /** * @brief Send messages to multiple output channels. * @@ -168,12 +173,18 @@ Node unbounded_fo_send_task( // make reservations for each message so that it will fallback to host memory // if needed - // TODO: change this - auto res = - ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)[0]); - RAPIDSMPF_EXPECTS( - co_await ch_out->send(msg.copy(res)), "failed to send message" - ); + auto res = ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); + if (!co_await ch_out->send(msg.copy(res))) { + // Failed to send message. Could be that the channel is shut down. 
+ // So we need to abort the send task, and notify + // the process input task + { + auto lock = co_await state->mtx.scoped_lock(); + state->ch_next_idx[idx] = InvalidIdx; + } + co_await state->data_ready.notify_one(); + co_return; + } } logger.trace( "sent ", idx, " [", state->ch_next_idx[idx], ", ", curr_recv_msg_sz, ")" @@ -185,8 +196,8 @@ Node unbounded_fo_send_task( state->ch_next_idx[idx] = curr_recv_msg_sz; if (state->ch_next_idx[idx] == state->recv_messages.size()) { if (state->input_done) { - break; // no more messages will be received, and all messages have been - // sent + // no more messages will be received, and all messages have been sent + break; } else { // request more data from the input channel lock.unlock(); @@ -199,12 +210,38 @@ Node unbounded_fo_send_task( logger.trace("Send task ", idx, " completed"); } +/** + * @brief RAII helper class to close the unbounded fanout state when it goes out of + * scope. + */ +struct UnboundedFanoutStateCloser { + std::shared_ptr state; + + ~UnboundedFanoutStateCloser() { + // forcibly set input_done to true and notify all send tasks to wind down + coro::sync_wait([](auto&& s) -> coro::task { + auto lock = co_await s->mtx.scoped_lock(); + s->input_done = true; + co_await s->data_ready.notify_all(); + }(std::move(state))); + } +}; + +/** + * @brief Process input messages and notify send tasks to copy & send messages. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param state The state of the unbounded fanout. + * @return A coroutine representing the task. + */ Node unbounded_fo_process_input_task( Context& ctx, std::shared_ptr ch_in, std::shared_ptr state ) { ShutdownAtExit ch_in_shutdown{ch_in}; + UnboundedFanoutStateCloser state_closer{state}; co_await ctx.executor()->schedule(); auto& logger = ctx.logger(); @@ -212,18 +249,28 @@ Node unbounded_fo_process_input_task( // input_done is only set by this task, so reading without lock is safe here while (!state->input_done) { - size_t last_completed_idx, latest_processed_idx; + size_t last_completed_idx = InvalidIdx, latest_processed_idx = 0; { auto lock = co_await state->mtx.scoped_lock(); co_await state->request_data.wait(lock, [&] { - auto res = std::ranges::minmax(state->ch_next_idx); - last_completed_idx = res.min; - latest_processed_idx = res.max; - - return latest_processed_idx == state->recv_messages.size(); + for (auto idx : state->ch_next_idx) { + if (idx != InvalidIdx) { + last_completed_idx = std::min(last_completed_idx, idx); + latest_processed_idx = std::max(latest_processed_idx, idx); + } + } + // if min idx was never updated, that means all send tasks are in an + // invalid state + return (last_completed_idx == InvalidIdx) + || (latest_processed_idx == state->recv_messages.size()); }); } + // all send tasks are in an invalid state, so we can break + if (last_completed_idx == InvalidIdx) { + break; + } + // receive a message from the input channel auto msg = co_await ch_in->receive(); @@ -240,12 +287,11 @@ Node unbounded_fo_process_input_task( // notify send_tasks to copy & send messages co_await state->data_ready.notify_all(); - // purge completed send_tasks. This will reset the messages to empty, so that they - // release the memory, however the deque is not resized. This guarantees that the - // indices are not invalidated. 
- // intentionally not locking the mtx here, because we only need to know a - // lower-bound on the last completed idx (ch_next_idx values are monotonically - // increasing) + // purge completed send_tasks. This will reset the messages to empty, so that + // they release the memory, however the deque is not resized. This guarantees + // that the indices are not invalidated. intentionally not locking the mtx + // here, because we only need to know a lower-bound on the last completed idx + // (ch_next_idx values are monotonically increasing) while (state->purge_idx + 1 < last_completed_idx) { state->recv_messages[state->purge_idx].reset(); state->purge_idx++; @@ -263,11 +309,11 @@ Node unbounded_fo_process_input_task( * @brief Broadcast messages from one input channel to multiple output channels. * * This is an all-purpose implementation that can support consuming messages by the - * channel order or message order. Output channels could be connected to single/multiple - * consumer nodes. A consumer node can decide to consume all messages from a single - * channel before moving to the next channel, or it can consume messages from all channels - * before moving to the next message. When a message has been sent to all output channels, - * it is purged from the internal deque. + * channel order or message order. Output channels could be connected to + * single/multiple consumer nodes. A consumer node can decide to consume all messages + * from a single channel before moving to the next channel, or it can consume messages + * from all channels before moving to the next message. When a message has been sent + * to all output channels, it is purged from the internal deque. * * @param ctx The context to use. * @param ch_in The input channel to receive messages from. @@ -314,6 +360,12 @@ Node fanout( ) { RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); + // if there is only one output channel, both bounded and unbounded implementations are + // semantically equivalent. So we can use the bounded fanout implementation. + if (chs_out.size() == 1) { + return bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); + } + switch (policy) { case FanoutPolicy::BOUNDED: return bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); From 0856d964f77f1c335c7968788385099a74c6cbb1 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 15:20:55 -0800 Subject: [PATCH 21/43] working premature shutdown Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 74 ++++++++++++++++------------ cpp/src/streaming/core/leaf_node.cpp | 4 +- cpp/tests/streaming/test_fanout.cpp | 70 ++++++++++++++++++++++++-- 3 files changed, 111 insertions(+), 37 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 2ef4f6ac3..6e3ee0a92 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -58,6 +58,9 @@ Node send_to_channels( } // move the message to the last channel to avoid extra copy tasks.emplace_back(chs_out.back()->send(std::move(msg))); + + // note that the send tasks may return false if the channel is shut down. But it does + // not affect the bounded fanout implementation. 
coro_results(co_await coro::when_all(std::move(tasks))); } @@ -88,7 +91,7 @@ Node bounded_fanout( } co_await send_to_channels(ctx.get(), std::move(msg), chs_out); - logger.debug("Sent message ", msg.sequence_number()); + logger.trace("Sent message ", msg.sequence_number()); } std::vector drain_tasks; @@ -97,7 +100,7 @@ Node bounded_fanout( drain_tasks.emplace_back(ch->drain(ctx->executor())); } coro_results(co_await coro::when_all(std::move(drain_tasks))); - logger.debug("Completed bounded fanout"); + logger.trace("Completed bounded fanout"); } /** @@ -152,6 +155,7 @@ Node unbounded_fo_send_task( while (true) { { auto lock = co_await state->mtx.scoped_lock(); + logger.trace("before data_ready wait ", idx); co_await state->data_ready.wait(lock, [&] { // irrespective of input_done, update the end_idx to the total number of // messages @@ -176,13 +180,13 @@ Node unbounded_fo_send_task( auto res = ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); if (!co_await ch_out->send(msg.copy(res))) { // Failed to send message. Could be that the channel is shut down. - // So we need to abort the send task, and notify - // the process input task + // So we need to abort the send task, and notify the process input task { auto lock = co_await state->mtx.scoped_lock(); state->ch_next_idx[idx] = InvalidIdx; } - co_await state->data_ready.notify_one(); + // notify the process input task to check if it should break + co_await state->request_data.notify_one(); co_return; } } @@ -211,19 +215,23 @@ Node unbounded_fo_send_task( } /** - * @brief RAII helper class to close the unbounded fanout state when it goes out of - * scope. + * @brief RAII helper class to set input_done to true and notify all send tasks to wind + * down when the unbounded fanout state goes out of scope. 
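+ * This keeps the send tasks from blocking on `data_ready` forever when the
+ * input-processing task exits early, e.g. due to an exception or because
+ * every output channel has shut down.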
*/ -struct UnboundedFanoutStateCloser { +struct StateInputDoneAtExit { std::shared_ptr state; - ~UnboundedFanoutStateCloser() { - // forcibly set input_done to true and notify all send tasks to wind down - coro::sync_wait([](auto&& s) -> coro::task { - auto lock = co_await s->mtx.scoped_lock(); - s->input_done = true; - co_await s->data_ready.notify_all(); - }(std::move(state))); + // forcibly set input_done to true and notify all send tasks to wind down + Node set_input_done() { + { + auto lock = co_await state->mtx.scoped_lock(); + state->input_done = true; + } + co_await state->data_ready.notify_all(); + } + + ~StateInputDoneAtExit() { + coro::sync_wait(set_input_done()); } }; @@ -241,7 +249,7 @@ Node unbounded_fo_process_input_task( std::shared_ptr state ) { ShutdownAtExit ch_in_shutdown{ch_in}; - UnboundedFanoutStateCloser state_closer{state}; + StateInputDoneAtExit state_closer{state}; co_await ctx.executor()->schedule(); auto& logger = ctx.logger(); @@ -249,26 +257,31 @@ Node unbounded_fo_process_input_task( // input_done is only set by this task, so reading without lock is safe here while (!state->input_done) { - size_t last_completed_idx = InvalidIdx, latest_processed_idx = 0; + size_t last_completed_idx = InvalidIdx, latest_processed_idx = InvalidIdx; { auto lock = co_await state->mtx.scoped_lock(); + logger.trace("before request_data wait"); co_await state->request_data.wait(lock, [&] { - for (auto idx : state->ch_next_idx) { - if (idx != InvalidIdx) { - last_completed_idx = std::min(last_completed_idx, idx); - latest_processed_idx = std::max(latest_processed_idx, idx); - } + auto filtered = state->ch_next_idx | std::views::filter([](size_t idx) { + return idx != InvalidIdx; + }); + + if (std::ranges::empty(filtered)) { + // no valid indices, so all send tasks are in an invalid state + return true; } - // if min idx was never updated, that means all send tasks are in an - // invalid state - return (last_completed_idx == InvalidIdx) - || (latest_processed_idx == state->recv_messages.size()); + + auto [min_val, max_val] = std::ranges::minmax(filtered); + last_completed_idx = min_val; + latest_processed_idx = max_val; + + return latest_processed_idx == state->recv_messages.size(); }); - } - // all send tasks are in an invalid state, so we can break - if (last_completed_idx == InvalidIdx) { - break; + // all send tasks are in an invalid state, so we can break + if (last_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx) { + break; + } } // receive a message from the input channel @@ -279,7 +292,6 @@ Node unbounded_fo_process_input_task( if (msg.empty()) { state->input_done = true; } else { - logger.trace("Received input", msg.sequence_number()); state->recv_messages.emplace_back(std::move(msg)); } } diff --git a/cpp/src/streaming/core/leaf_node.cpp b/cpp/src/streaming/core/leaf_node.cpp index 74f259126..c2b380f1c 100644 --- a/cpp/src/streaming/core/leaf_node.cpp +++ b/cpp/src/streaming/core/leaf_node.cpp @@ -18,9 +18,7 @@ Node push_to_channel( for (auto& msg : messages) { RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty", std::invalid_argument); - RAPIDSMPF_EXPECTS( - co_await ch_out->send(std::move(msg)), "failed to send message" - ); + co_await ch_out->send(std::move(msg)); } co_await ch_out->drain(ctx->executor()); } diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 0bff6ee33..0f4ca7dd7 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -129,6 +129,70 @@ 
TEST_P(StreamingFanout, SinkPerChannel) { } } +namespace { + +Node shutdown_channel_after_n_messages( + std::shared_ptr ctx, + std::shared_ptr ch_in, + std::vector& out_messages, + size_t max_messages +) { + ShutdownAtExit c{ch_in}; + co_await ctx->executor()->schedule(); + + for (size_t i = 0; i < max_messages; ++i) { + auto msg = co_await ch_in->receive(); + if (msg.empty()) { + break; + } + out_messages.push_back(std::move(msg)); + } + co_await ch_in->shutdown(); +} + +} // namespace + +TEST_P(StreamingFanout, SinkPerChannel_ShutdownHalfWay) { + // Prepare inputs + auto inputs = make_int_inputs(num_msgs); + + // Create pipeline + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back( + shutdown_channel_after_n_messages(ctx, out_chs[i], outs[i], num_msgs / 2) + ); + } + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + // Validate sizes + EXPECT_EQ(outs[c].size(), static_cast(num_msgs / 2)); + + // Validate ordering/content and that shallow copies share the same underlying + // object + for (int i = 0; i < num_msgs / 2; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } + } +} + enum class ConsumePolicy : uint8_t { CHANNEL_ORDER, // consume all messages from a single channel before moving to the // next @@ -204,9 +268,9 @@ struct ManyInputSinkStreamingFanout : public StreamingFanout { std::vector actual; actual.reserve(outs[c].size()); std::ranges::transform( - outs[c], std::back_inserter(actual), [](const Message& m) { - return m.get(); - } + outs[c], + std::back_inserter(actual), + [](const Message& m) { return m.get(); } ); EXPECT_EQ(expected, actual); } From f740a744106c71e74354e22c99bd93d29f2182dc Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 16:00:54 -0800 Subject: [PATCH 22/43] minor improvements Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 60 ++++++++++++++++--------- cpp/tests/streaming/test_fanout.cpp | 69 ++++++++++++++++++++++++----- 2 files changed, 97 insertions(+), 32 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 6e3ee0a92..33ceef23f 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -39,8 +39,8 @@ constexpr std::span try_memory_types(Message const& msg) { /** * @brief Asynchronously send a message to multiple output channels. * - * @param msg The message to broadcast. Each channel receives a shallow - * copy of the original message. + * @param msg The message to broadcast. Each channel receives a deep copy of the original + * message. * @param chs_out The set of output channels to which the message is sent. */ Node send_to_channels( @@ -59,8 +59,8 @@ Node send_to_channels( // move the message to the last channel to avoid extra copy tasks.emplace_back(chs_out.back()->send(std::move(msg))); - // note that the send tasks may return false if the channel is shut down. But it does - // not affect the bounded fanout implementation. + // note that the send tasks may return false if the channel is shut down. But we can + // safely ignore this in bounded fanout. 
coro_results(co_await coro::when_all(std::move(tasks))); } @@ -131,6 +131,30 @@ struct UnboundedFanoutState { */ constexpr size_t InvalidIdx = std::numeric_limits::max(); +/** + * @brief RAII helper class to set a channel index to invalid and notify the process input + * task to check if it should break. + */ +struct SetChannelIdxInvalidAtExit { + std::shared_ptr state; + size_t idx; + + ~SetChannelIdxInvalidAtExit() { + coro::sync_wait(set_channel_idx_invalid()); + } + + Node set_channel_idx_invalid() { + if (state) { + { + auto lock = co_await state->mtx.scoped_lock(); + state->ch_next_idx[idx] = InvalidIdx; + } + co_await state->request_data.notify_one(); + } + state.reset(); + } +}; + /** * @brief Send messages to multiple output channels. * @@ -147,6 +171,7 @@ Node unbounded_fo_send_task( std::shared_ptr state ) { ShutdownAtExit ch_shutdown{ch_out}; + SetChannelIdxInvalidAtExit set_ch_idx_invalid{.state = state, .idx = idx}; co_await ctx.executor()->schedule(); auto& logger = ctx.logger(); @@ -155,7 +180,6 @@ Node unbounded_fo_send_task( while (true) { { auto lock = co_await state->mtx.scoped_lock(); - logger.trace("before data_ready wait ", idx); co_await state->data_ready.wait(lock, [&] { // irrespective of input_done, update the end_idx to the total number of // messages @@ -181,12 +205,7 @@ Node unbounded_fo_send_task( if (!co_await ch_out->send(msg.copy(res))) { // Failed to send message. Could be that the channel is shut down. // So we need to abort the send task, and notify the process input task - { - auto lock = co_await state->mtx.scoped_lock(); - state->ch_next_idx[idx] = InvalidIdx; - } - // notify the process input task to check if it should break - co_await state->request_data.notify_one(); + co_await set_ch_idx_invalid.set_channel_idx_invalid(); co_return; } } @@ -221,6 +240,10 @@ Node unbounded_fo_send_task( struct StateInputDoneAtExit { std::shared_ptr state; + ~StateInputDoneAtExit() { + coro::sync_wait(set_input_done()); + } + // forcibly set input_done to true and notify all send tasks to wind down Node set_input_done() { { @@ -229,10 +252,6 @@ struct StateInputDoneAtExit { } co_await state->data_ready.notify_all(); } - - ~StateInputDoneAtExit() { - coro::sync_wait(set_input_done()); - } }; /** @@ -260,18 +279,17 @@ Node unbounded_fo_process_input_task( size_t last_completed_idx = InvalidIdx, latest_processed_idx = InvalidIdx; { auto lock = co_await state->mtx.scoped_lock(); - logger.trace("before request_data wait"); co_await state->request_data.wait(lock, [&] { - auto filtered = state->ch_next_idx | std::views::filter([](size_t idx) { - return idx != InvalidIdx; - }); + auto filtered_view = std::ranges::filter_view( + state->ch_next_idx, [](size_t idx) { return idx != InvalidIdx; } + ); - if (std::ranges::empty(filtered)) { + if (std::ranges::empty(filtered_view)) { // no valid indices, so all send tasks are in an invalid state return true; } - auto [min_val, max_val] = std::ranges::minmax(filtered); + auto [min_val, max_val] = std::ranges::minmax(filtered_view); last_completed_idx = min_val; latest_processed_idx = max_val; diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 0f4ca7dd7..d1418b159 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -91,10 +91,8 @@ INSTANTIATE_TEST_SUITE_P( ); TEST_P(StreamingFanout, SinkPerChannel) { - // Prepare inputs auto inputs = make_int_inputs(num_msgs); - // Create pipeline std::vector> outs(num_out_chs); { std::vector nodes; @@ -131,6 +129,16 
@@ TEST_P(StreamingFanout, SinkPerChannel) { namespace { +/** + * @brief A node that pulls and shuts down a channel after a certain number of messages + * have been received. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @param out_messages The output messages to store the received messages in. + * @param max_messages The maximum number of messages to receive. + * @return A coroutine representing the task. + */ Node shutdown_channel_after_n_messages( std::shared_ptr ctx, std::shared_ptr ch_in, @@ -152,11 +160,10 @@ Node shutdown_channel_after_n_messages( } // namespace +// all channels shutsdown after receiving num_msgs / 2 messages TEST_P(StreamingFanout, SinkPerChannel_ShutdownHalfWay) { - // Prepare inputs auto inputs = make_int_inputs(num_msgs); - // Create pipeline std::vector> outs(num_out_chs); { std::vector nodes; @@ -181,11 +188,8 @@ TEST_P(StreamingFanout, SinkPerChannel_ShutdownHalfWay) { } for (int c = 0; c < num_out_chs; ++c) { - // Validate sizes - EXPECT_EQ(outs[c].size(), static_cast(num_msgs / 2)); + EXPECT_EQ(static_cast(num_msgs / 2), outs[c].size()); - // Validate ordering/content and that shallow copies share the same underlying - // object for (int i = 0; i < num_msgs / 2; ++i) { SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); EXPECT_EQ(outs[c][i].get(), i); @@ -193,6 +197,49 @@ TEST_P(StreamingFanout, SinkPerChannel_ShutdownHalfWay) { } } +// only odd channels shutdown after receiving num_msgs / 2 messages, others continue to +// receive all messages +TEST_P(StreamingFanout, SinkPerChannel_OddChannelsShutdownHalfWay) { + auto inputs = make_int_inputs(num_msgs); + + std::vector> outs(num_out_chs); + { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + for (int i = 0; i < num_out_chs; ++i) { + if (i % 2 == 0) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], outs[i])); + } else { + nodes.emplace_back(shutdown_channel_after_n_messages( + ctx, out_chs[i], outs[i], num_msgs / 2 + )); + } + } + + run_streaming_pipeline(std::move(nodes)); + } + + for (int c = 0; c < num_out_chs; ++c) { + int expected_size = c % 2 == 0 ? 
num_msgs : num_msgs / 2; + EXPECT_EQ(outs[c].size(), expected_size); + + for (int i = 0; i < expected_size; ++i) { + SCOPED_TRACE("channel " + std::to_string(c) + " idx " + std::to_string(i)); + EXPECT_EQ(outs[c][i].get(), i); + } + } +} + enum class ConsumePolicy : uint8_t { CHANNEL_ORDER, // consume all messages from a single channel before moving to the // next @@ -268,9 +315,9 @@ struct ManyInputSinkStreamingFanout : public StreamingFanout { std::vector actual; actual.reserve(outs[c].size()); std::ranges::transform( - outs[c], - std::back_inserter(actual), - [](const Message& m) { return m.get(); } + outs[c], std::back_inserter(actual), [](const Message& m) { + return m.get(); + } ); EXPECT_EQ(expected, actual); } From 8e3e765f52510378d519f3e87eda3f34cd47e720 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 16:10:15 -0800 Subject: [PATCH 23/43] cull comments Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 33ceef23f..739e4803c 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -365,13 +365,12 @@ Node unbounded_fanout( tasks.reserve(chs_out.size() + 1); auto& executor = *ctx->executor(); - // schedule send tasks for each output channel + for (size_t i = 0; i < chs_out.size(); i++) { tasks.emplace_back(executor.schedule( unbounded_fo_send_task(*ctx, i, std::move(chs_out[i]), state) )); } - // schedule process input task tasks.emplace_back(executor.schedule( unbounded_fo_process_input_task(*ctx, std::move(ch_in), std::move(state)) )); From 0ee662e403943e0905e716c423e74c7d3d3c45f8 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 16:28:39 -0800 Subject: [PATCH 24/43] adding throwing tests Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 57 +++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index d1418b159..b47b46390 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -158,6 +158,12 @@ Node shutdown_channel_after_n_messages( co_await ch_in->shutdown(); } +Node throwing_node(std::shared_ptr ctx, std::shared_ptr ch_out) { + ShutdownAtExit c{ch_out}; + co_await ctx->executor()->schedule(); + throw std::logic_error("throwing source"); +} + } // namespace // all channels shutsdown after receiving num_msgs / 2 messages @@ -240,6 +246,57 @@ TEST_P(StreamingFanout, SinkPerChannel_OddChannelsShutdownHalfWay) { } } +// tests that throwing a source node propagates the error to the pipeline. This test will +// throw, but it should not hang. +TEST_P(StreamingFanout, SinkPerChannel_ThrowingSource) { + std::vector nodes; + + auto in = ctx->create_channel(); + nodes.emplace_back(throwing_node(ctx, in)); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + std::vector dummy_out; + for (int i = 0; i < num_out_chs; ++i) { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], dummy_out)); + } + + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::logic_error); +} + +// tests that throwing a sink node propagates the error to the pipeline. This test +// will throw, but it should not hang. 
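+// (The sink coroutine throws before receiving anything; its ShutdownAtExit then
+// closes the output channel, so the fanout's sends to that channel return false
+// instead of blocking on a consumer that never drains.)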
+TEST_P(StreamingFanout, SinkPerChannel_ThrowingSink) { + auto inputs = make_int_inputs(num_msgs); + + std::vector nodes; + auto in = ctx->create_channel(); + nodes.emplace_back(node::push_to_channel(ctx, in, std::move(inputs))); + + std::vector> out_chs; + for (int i = 0; i < num_out_chs; ++i) { + out_chs.emplace_back(ctx->create_channel()); + } + + nodes.emplace_back(node::fanout(ctx, in, out_chs, policy)); + + std::vector> dummy_outs(num_out_chs); + for (int i = 0; i < num_out_chs; ++i) { + if (i == 0) { + nodes.emplace_back(throwing_node(ctx, out_chs[i])); + } else { + nodes.emplace_back(node::pull_from_channel(ctx, out_chs[i], dummy_outs[i])); + } + } + + EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::logic_error); +} + enum class ConsumePolicy : uint8_t { CHANNEL_ORDER, // consume all messages from a single channel before moving to the // next From 26217baa5b896fd366e71ac54adec27c178b4450 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 16:30:30 -0800 Subject: [PATCH 25/43] revert Signed-off-by: niranda perera --- cpp/src/streaming/core/channel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/streaming/core/channel.cpp b/cpp/src/streaming/core/channel.cpp index 0420bedbe..347d41790 100644 --- a/cpp/src/streaming/core/channel.cpp +++ b/cpp/src/streaming/core/channel.cpp @@ -17,7 +17,7 @@ coro::task Channel::send(Message msg) { coro::task Channel::receive() { auto msg_id = co_await rb_.consume(); if (msg_id.has_value()) { - co_return sm_->extract(msg_id.value()); + co_return sm_->extract(*msg_id); } else { co_return Message{}; } From da3a9fe512f64d9f0750f93f139c9aba9125c180 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 17:13:19 -0800 Subject: [PATCH 26/43] reducing test permutations Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 48 ++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index b47b46390..f99744959 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -65,6 +65,11 @@ class StreamingFanout void SetUp() override { std::tie(policy, num_threads, num_out_chs, num_msgs) = GetParam(); SetUpWithThreads(num_threads); + + // restrict fanout tests to single communicator mode to reduce test runtime + if (GlobalEnvironment->type() != TestEnvironmentType::SINGLE) { + GTEST_SKIP() << "Skipping test in non-single communicator mode"; + } } FanoutPolicy policy; @@ -78,9 +83,9 @@ INSTANTIATE_TEST_SUITE_P( StreamingFanout, ::testing::Combine( ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), - ::testing::Values(1, 2, 4), // number of threads - ::testing::Values(1, 2, 4), // number of output channels - ::testing::Values(1, 10, 100) // number of messages + ::testing::Values(1, 4), // number of threads + ::testing::Values(1, 4), // number of output channels + ::testing::Values(10, 100) // number of messages ), [](testing::TestParamInfo const& info) { return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" @@ -246,9 +251,28 @@ TEST_P(StreamingFanout, SinkPerChannel_OddChannelsShutdownHalfWay) { } } +class ThrowingStreamingFanout : public StreamingFanout {}; + +INSTANTIATE_TEST_SUITE_P( + ThrowingStreamingFanout, + ThrowingStreamingFanout, + ::testing::Combine( + ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), + ::testing::Values(1, 4), // number of threads + ::testing::Values(4), // 
number of output channels + ::testing::Values(10) // number of messages + ), + [](testing::TestParamInfo const& info) { + return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" + + std::to_string(std::get<1>(info.param)) + "_nch_out_" + + std::to_string(std::get<2>(info.param)) + "_nmsgs_" + + std::to_string(std::get<3>(info.param)); + } +); + // tests that throwing a source node propagates the error to the pipeline. This test will // throw, but it should not hang. -TEST_P(StreamingFanout, SinkPerChannel_ThrowingSource) { +TEST_P(ThrowingStreamingFanout, ThrowingSource) { std::vector nodes; auto in = ctx->create_channel(); @@ -271,7 +295,7 @@ TEST_P(StreamingFanout, SinkPerChannel_ThrowingSource) { // tests that throwing a sink node propagates the error to the pipeline. This test // will throw, but it should not hang. -TEST_P(StreamingFanout, SinkPerChannel_ThrowingSink) { +TEST_P(ThrowingStreamingFanout, ThrowingSink) { auto inputs = make_int_inputs(num_msgs); std::vector nodes; @@ -297,6 +321,7 @@ TEST_P(StreamingFanout, SinkPerChannel_ThrowingSink) { EXPECT_THROW(run_streaming_pipeline(std::move(nodes)), std::logic_error); } +namespace { enum class ConsumePolicy : uint8_t { CHANNEL_ORDER, // consume all messages from a single channel before moving to the // next @@ -341,6 +366,7 @@ Node many_input_sink( } } } +} // namespace struct ManyInputSinkStreamingFanout : public StreamingFanout { void run(ConsumePolicy consume_policy) { @@ -372,9 +398,9 @@ struct ManyInputSinkStreamingFanout : public StreamingFanout { std::vector actual; actual.reserve(outs[c].size()); std::ranges::transform( - outs[c], std::back_inserter(actual), [](const Message& m) { - return m.get(); - } + outs[c], + std::back_inserter(actual), + [](const Message& m) { return m.get(); } ); EXPECT_EQ(expected, actual); } @@ -386,9 +412,9 @@ INSTANTIATE_TEST_SUITE_P( ManyInputSinkStreamingFanout, ::testing::Combine( ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), - ::testing::Values(1, 2, 4), // number of threads - ::testing::Values(1, 2, 4), // number of output channels - ::testing::Values(1, 10, 100) // number of messages + ::testing::Values(1, 4), // number of threads + ::testing::Values(1, 4), // number of output channels + ::testing::Values(10, 100) // number of messages ), [](testing::TestParamInfo const& info) { return "policy_" + policy_to_string(std::get<0>(info.param)) + "_nthreads_" From 15ba8ed919eb064166179c164d4f3b8cfec1d3aa Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 13 Nov 2025 17:22:01 -0800 Subject: [PATCH 27/43] precommit Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index f99744959..8f8d2d5e2 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -398,9 +398,9 @@ struct ManyInputSinkStreamingFanout : public StreamingFanout { std::vector actual; actual.reserve(outs[c].size()); std::ranges::transform( - outs[c], - std::back_inserter(actual), - [](const Message& m) { return m.get(); } + outs[c], std::back_inserter(actual), [](const Message& m) { + return m.get(); + } ); EXPECT_EQ(expected, actual); } From 5d1a305dc9878160c1b5c333dc0efa29eaae9c05 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 14 Nov 2025 12:01:29 -0800 Subject: [PATCH 28/43] minor improvement Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 11 +++++++---- 1 
file changed, 7 insertions(+), 4 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 739e4803c..200538749 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -284,14 +284,17 @@ Node unbounded_fo_process_input_task( state->ch_next_idx, [](size_t idx) { return idx != InvalidIdx; } ); - if (std::ranges::empty(filtered_view)) { + auto it = std::ranges::begin(filtered_view); // first valid idx + auto end = std::ranges::end(filtered_view); // end idx + + if (it == end) { // no valid indices, so all send tasks are in an invalid state return true; } - auto [min_val, max_val] = std::ranges::minmax(filtered_view); - last_completed_idx = min_val; - latest_processed_idx = max_val; + auto [min_it, max_it] = std::minmax_element(it, end); + last_completed_idx = *min_it; + latest_processed_idx = *max_it; return latest_processed_idx == state->recv_messages.size(); }); From f8833dbdb5fc81648afc83c6f373e533eb146603 Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Tue, 18 Nov 2025 08:08:55 -0800 Subject: [PATCH 29/43] Update cpp/src/streaming/core/fanout.cpp Co-authored-by: Mads R. B. Kristensen --- cpp/src/streaming/core/fanout.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 200538749..86dd5001c 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -308,7 +308,7 @@ Node unbounded_fo_process_input_task( // receive a message from the input channel auto msg = co_await ch_in->receive(); - { // relock mtx to update input_done/ recv_messages + { // relock mtx to update input_done/recv_messages auto lock = co_await state->mtx.scoped_lock(); if (msg.empty()) { state->input_done = true; From dca9378930c8e8090d0a56cad0574d26c53c445b Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 18 Nov 2025 12:06:34 -0800 Subject: [PATCH 30/43] addressing PR comments Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 85 +++++++++++++++++++------------ 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 200538749..bc047754e 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -121,8 +121,6 @@ struct UnboundedFanoutState { std::deque recv_messages; // next index to send for each channel std::vector ch_next_idx; - // index of the first message to purge - size_t purge_idx{0}; }; /** @@ -254,6 +252,44 @@ struct StateInputDoneAtExit { } }; +/** + * @brief Wait for a data request from the send tasks. + * + * @param state The state of the unbounded fanout. + * @param last_completed_idx The index of the last completed message. + * @param latest_processed_idx The index of the latest processed message. + * @return True if the state is valid and can move forward, false otherwise (all send + * tasks are in an invalid state). 
+ */ +auto wait_for_data_request( + UnboundedFanoutState& state, size_t& last_completed_idx, size_t& latest_processed_idx +) -> coro::task { + auto lock = co_await state.mtx.scoped_lock(); + co_await state.request_data.wait(lock, [&] { + auto filtered_view = std::ranges::filter_view(state.ch_next_idx, [](size_t idx) { + return idx != InvalidIdx; + }); + + auto it = std::ranges::begin(filtered_view); // first valid idx + auto end = std::ranges::end(filtered_view); // end idx + + if (it == end) { + // no valid indices, so all send tasks are in an invalid state + return true; + } + + auto [min_it, max_it] = std::minmax_element(it, end); + last_completed_idx = *min_it; + latest_processed_idx = *max_it; + + return latest_processed_idx == state.recv_messages.size(); + }); + + // if both last_completed_idx and latest_processed_idx are invalid, it means that all + // send tasks are in an invalid state. + co_return !(last_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx); +} + /** * @brief Process input messages and notify send tasks to copy & send messages. * @@ -274,35 +310,18 @@ Node unbounded_fo_process_input_task( logger.trace("Scheduled process input task"); + // index of the first message to purge + size_t purge_idx = 0; + // input_done is only set by this task, so reading without lock is safe here while (!state->input_done) { size_t last_completed_idx = InvalidIdx, latest_processed_idx = InvalidIdx; - { - auto lock = co_await state->mtx.scoped_lock(); - co_await state->request_data.wait(lock, [&] { - auto filtered_view = std::ranges::filter_view( - state->ch_next_idx, [](size_t idx) { return idx != InvalidIdx; } - ); - - auto it = std::ranges::begin(filtered_view); // first valid idx - auto end = std::ranges::end(filtered_view); // end idx - - if (it == end) { - // no valid indices, so all send tasks are in an invalid state - return true; - } - - auto [min_it, max_it] = std::minmax_element(it, end); - last_completed_idx = *min_it; - latest_processed_idx = *max_it; - return latest_processed_idx == state->recv_messages.size(); - }); - - // all send tasks are in an invalid state, so we can break - if (last_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx) { - break; - } + if (!co_await wait_for_data_request( + *state, last_completed_idx, latest_processed_idx + )) + { + break; // waiting returned an invalid state, so we need to break } // receive a message from the input channel @@ -325,12 +344,12 @@ Node unbounded_fo_process_input_task( // that the indices are not invalidated. 
intentionally not locking the mtx // here, because we only need to know a lower-bound on the last completed idx // (ch_next_idx values are monotonically increasing) - while (state->purge_idx + 1 < last_completed_idx) { - state->recv_messages[state->purge_idx].reset(); - state->purge_idx++; + while (purge_idx + 1 < last_completed_idx) { + state->recv_messages[purge_idx].reset(); + purge_idx++; } logger.trace( - "recv_messages active size: ", state->recv_messages.size() - state->purge_idx + "recv_messages active size: ", state->recv_messages.size() - purge_idx ); } @@ -358,6 +377,8 @@ Node unbounded_fanout( std::shared_ptr ch_in, std::vector> chs_out ) { + auto& executor = *ctx->executor(); + ShutdownAtExit ch_in_shutdown{ch_in}; ShutdownAtExit chs_out_shutdown{chs_out}; co_await ctx->executor()->schedule(); @@ -367,8 +388,6 @@ Node unbounded_fanout( std::vector tasks; tasks.reserve(chs_out.size() + 1); - auto& executor = *ctx->executor(); - for (size_t i = 0; i < chs_out.size(); i++) { tasks.emplace_back(executor.schedule( unbounded_fo_send_task(*ctx, i, std::move(chs_out[i]), state) From 16fb1894ec3ac37ec3bd059ebc6bfa83cafc7849 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 19 Nov 2025 10:57:43 -0800 Subject: [PATCH 31/43] addressing comments Signed-off-by: niranda perera --- .../rapidsmpf/streaming/core/fanout.hpp | 5 +- cpp/src/streaming/core/fanout.cpp | 486 +++++++++--------- cpp/tests/streaming/test_fanout.cpp | 16 +- .../rapidsmpf/tests/streaming/test_fanout.py | 5 + 4 files changed, 271 insertions(+), 241 deletions(-) diff --git a/cpp/include/rapidsmpf/streaming/core/fanout.hpp b/cpp/include/rapidsmpf/streaming/core/fanout.hpp index 59a64cc6d..c15193272 100644 --- a/cpp/include/rapidsmpf/streaming/core/fanout.hpp +++ b/cpp/include/rapidsmpf/streaming/core/fanout.hpp @@ -49,12 +49,13 @@ enum class FanoutPolicy : uint8_t { * * @param ctx The node context to use. * @param ch_in Input channel from which messages are received. - * @param chs_out Output channels to which messages are broadcast. + * @param chs_out Output channels to which messages are broadcast. Must be at least 2. * @param policy The fanout strategy to use (see ::FanoutPolicy). * * @return Streaming node representing the fanout operation. * - * @throws std::invalid_argument If an unknown fanout policy is specified. + * @throws std::invalid_argument If an unknown fanout policy is specified or if the number + * of output channels is less than 2. */ Node fanout( std::shared_ptr ctx, diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 6f4f074be..465b295fb 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -104,258 +104,274 @@ Node bounded_fanout( } /** - * @brief State for the unbounded fanout. - */ -struct UnboundedFanoutState { - UnboundedFanoutState(size_t num_channels) : ch_next_idx(num_channels, 0) {} - - coro::mutex mtx; - // notify send tasks to copy & send messages - coro::condition_variable data_ready; - // notify this task to receive more data from the input channel - coro::condition_variable request_data; - // set to true when the input channel is fully consumed - bool input_done{false}; - // messages received from the input channel. We use a deque to avoid references being - // invalidated by reallocations. - std::deque recv_messages; - // next index to send for each channel - std::vector ch_next_idx; -}; - -/** - * @brief Sentinel value indicating that the index is invalid. 
This is set when a failure - * occurs during send tasks. process input task will filter out messages with this index. - */ -constexpr size_t InvalidIdx = std::numeric_limits::max(); - -/** - * @brief RAII helper class to set a channel index to invalid and notify the process input - * task to check if it should break. - */ -struct SetChannelIdxInvalidAtExit { - std::shared_ptr state; - size_t idx; - - ~SetChannelIdxInvalidAtExit() { - coro::sync_wait(set_channel_idx_invalid()); - } - - Node set_channel_idx_invalid() { - if (state) { - { - auto lock = co_await state->mtx.scoped_lock(); - state->ch_next_idx[idx] = InvalidIdx; - } - co_await state->request_data.notify_one(); - } - state.reset(); - } -}; - -/** - * @brief Send messages to multiple output channels. + * @brief Unbounded fanout implementation. + * + * The implementation follows a pull-based model, where the send tasks request data from + * the recv task. There is one recv task that receives messages from the input channel, + * and there are N send tasks that send messages to the output channels. + * + * Main task operation: + * - There is a shared deque of cached messages, and a vector that indicates the next + * index of the message to be sent to each output channel. + * - Recv task awaits until the number of cached messages is equal to the latest sent + * message index by any of the send tasks. This notifies the recv task to pull a message + * from the input channel, cache it, and notify the send tasks about the new messages. + * recv task continues this process until the input channel is fully consumed. + * - Each send task awaits until there are more cached messages to send. Once notified, it + * determines the current end of the cached messages, and sends messages in the range + * [next_idx, end_idx). Once these messages have been sent, it updates the next index to + * end_idx and notifies the recv task. + * + * Additional considerations: + * - In the recv task loop, it also identifies the last completed message index by all + * send tasks. Message upto this index are no longer needed, and are purged from the + * cached messages. + * - When a send task fails to send a message, this means the channel may have been + * prematurely shut down. In this case, it sets its index to InvalidIdx. Recv task will + * filter out channels with InvalidIdx. + * - There two RAII helpers to ensure that the notification mechanisms are properly + * cleaned up when the unbounded fanout state goes out of scope/ encounters an error. * - * @param ctx The context to use. - * @param idx The index of the task - * @param ch_out The output channel to send messages to. - * @param state The state of the unbounded fanout. - * @return A coroutine representing the task. 
*/ -Node unbounded_fo_send_task( - Context& ctx, - size_t idx, - std::shared_ptr ch_out, - std::shared_ptr state -) { - ShutdownAtExit ch_shutdown{ch_out}; - SetChannelIdxInvalidAtExit set_ch_idx_invalid{.state = state, .idx = idx}; - co_await ctx.executor()->schedule(); - - auto& logger = ctx.logger(); - - size_t curr_recv_msg_sz = 0; // current size of the recv_messages deque - while (true) { - { - auto lock = co_await state->mtx.scoped_lock(); - co_await state->data_ready.wait(lock, [&] { - // irrespective of input_done, update the end_idx to the total number of - // messages - curr_recv_msg_sz = state->recv_messages.size(); - return state->input_done || state->ch_next_idx[idx] < curr_recv_msg_sz; - }); - if (state->input_done && state->ch_next_idx[idx] == curr_recv_msg_sz) { - // no more messages will be received, and all messages have been sent - break; - } +struct UnboundedFanout { + /** + * @brief Constructor. + * + * @param num_channels The number of output channels. + */ + explicit UnboundedFanout(size_t num_channels) : ch_next_idx(num_channels, 0) {} + + /** + * @brief Sentinel value indicating that the index is invalid. This is set when a + * failure occurs during send tasks. process input task will filter out messages with + * this index. + */ + static constexpr size_t InvalidIdx = std::numeric_limits::max(); + + /** + * @brief RAII helper class to set a channel index to invalid and notify the process + * input task to check if it should break. + */ + struct SetChannelIdxInvalidAtExit { + UnboundedFanout* fanout; + size_t idx; + + ~SetChannelIdxInvalidAtExit() { + coro::sync_wait(set_channel_idx_invalid()); } - // now we can copy & send messages in indices [next_idx, curr_recv_msg_sz) - // it is guaranteed that message purging will be done only on indices less than - // next_idx, so we can safely send messages without locking the mtx - for (size_t i = state->ch_next_idx[idx]; i < curr_recv_msg_sz; i++) { - auto const& msg = state->recv_messages[i]; - RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); - - // make reservations for each message so that it will fallback to host memory - // if needed - auto res = ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); - if (!co_await ch_out->send(msg.copy(res))) { - // Failed to send message. Could be that the channel is shut down. - // So we need to abort the send task, and notify the process input task - co_await set_ch_idx_invalid.set_channel_idx_invalid(); - co_return; + Node set_channel_idx_invalid() { + if (idx != InvalidIdx) { + { + auto lock = co_await fanout->mtx.scoped_lock(); + fanout->ch_next_idx[idx] = InvalidIdx; + } + co_await fanout->request_data.notify_one(); } + idx = InvalidIdx; } - logger.trace( - "sent ", idx, " [", state->ch_next_idx[idx], ", ", curr_recv_msg_sz, ")" - ); + }; + + /** + * @brief Send messages to multiple output channels. + * + * @param ctx The context to use. + * @param self Self index of the task + * @param ch_out The output channel to send messages to. + * @param state The state of the unbounded fanout. + * @return A coroutine representing the task. 
+ */ + Node send_task(Context& ctx, size_t self, std::shared_ptr ch_out) { + ShutdownAtExit ch_shutdown{ch_out}; + SetChannelIdxInvalidAtExit set_ch_idx_invalid{.fanout = this, .idx = self}; + co_await ctx.executor()->schedule(); + + auto& logger = ctx.logger(); + + size_t curr_recv_msg_sz = 0; // current size of the recv_messages deque + while (true) { + { + auto lock = co_await mtx.scoped_lock(); + co_await data_ready.wait(lock, [&] { + // irrespective of input_done, update the end_idx to the total number + // of messages + curr_recv_msg_sz = recv_messages.size(); + return input_done || ch_next_idx[self] < curr_recv_msg_sz; + }); + if (input_done && ch_next_idx[self] == curr_recv_msg_sz) { + // no more messages will be received, and all messages have been sent + break; + } + } - // now next_idx can be updated to end_idx, and if !input_done, we need to request - // parent task for more data - auto lock = co_await state->mtx.scoped_lock(); - state->ch_next_idx[idx] = curr_recv_msg_sz; - if (state->ch_next_idx[idx] == state->recv_messages.size()) { - if (state->input_done) { - // no more messages will be received, and all messages have been sent - break; - } else { - // request more data from the input channel - lock.unlock(); - co_await state->request_data.notify_one(); + // now we can copy & send messages in indices [next_idx, curr_recv_msg_sz) + // it is guaranteed that message purging will be done only on indices less + // than next_idx, so we can safely send messages without locking the mtx + for (size_t i = ch_next_idx[self]; i < curr_recv_msg_sz; i++) { + auto const& msg = recv_messages[i]; + RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); + + // make reservations for each message so that it will fallback to host + // memory if needed + auto res = + ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); + if (!co_await ch_out->send(msg.copy(res))) { + // Failed to send message. Could be that the channel is shut down. + // So we need to abort the send task, and notify the process input + // task + co_await set_ch_idx_invalid.set_channel_idx_invalid(); + co_return; + } + } + logger.trace( + "sent ", self, " [", ch_next_idx[self], ", ", curr_recv_msg_sz, ")" + ); + + // now next_idx can be updated to end_idx, and if !input_done, we need to + // request the recv task for more data + auto lock = co_await mtx.scoped_lock(); + ch_next_idx[self] = curr_recv_msg_sz; + if (ch_next_idx[self] == recv_messages.size()) { + if (input_done) { + // no more messages will be received, and all messages have been sent + break; + } else { + // request more data from the input channel + lock.unlock(); + co_await request_data.notify_one(); + } } } - } - - co_await ch_out->drain(ctx.executor()); - logger.trace("Send task ", idx, " completed"); -} - -/** - * @brief RAII helper class to set input_done to true and notify all send tasks to wind - * down when the unbounded fanout state goes out of scope. - */ -struct StateInputDoneAtExit { - std::shared_ptr state; - ~StateInputDoneAtExit() { - coro::sync_wait(set_input_done()); + co_await ch_out->drain(ctx.executor()); + logger.trace("Send task ", self, " completed"); } - // forcibly set input_done to true and notify all send tasks to wind down - Node set_input_done() { - { - auto lock = co_await state->mtx.scoped_lock(); - state->input_done = true; - } - co_await state->data_ready.notify_all(); - } -}; - -/** - * @brief Wait for a data request from the send tasks. - * - * @param state The state of the unbounded fanout. 
- * @param last_completed_idx The index of the last completed message. - * @param latest_processed_idx The index of the latest processed message. - * @return True if the state is valid and can move forward, false otherwise (all send - * tasks are in an invalid state). - */ -auto wait_for_data_request( - UnboundedFanoutState& state, size_t& last_completed_idx, size_t& latest_processed_idx -) -> coro::task { - auto lock = co_await state.mtx.scoped_lock(); - co_await state.request_data.wait(lock, [&] { - auto filtered_view = std::ranges::filter_view(state.ch_next_idx, [](size_t idx) { - return idx != InvalidIdx; - }); + /** + * @brief RAII helper class to set input_done to true and notify all send tasks to + * wind down when the unbounded fanout state goes out of scope. + */ + struct SetInputDoneAtExit { + UnboundedFanout* fanout; - auto it = std::ranges::begin(filtered_view); // first valid idx - auto end = std::ranges::end(filtered_view); // end idx - - if (it == end) { - // no valid indices, so all send tasks are in an invalid state - return true; + ~SetInputDoneAtExit() { + coro::sync_wait(set_input_done()); } - auto [min_it, max_it] = std::minmax_element(it, end); - last_completed_idx = *min_it; - latest_processed_idx = *max_it; + // forcibly set input_done to true and notify all send tasks to wind down + Node set_input_done() { + { + auto lock = co_await fanout->mtx.scoped_lock(); + fanout->input_done = true; + } + co_await fanout->data_ready.notify_all(); + } + }; + + /** + * @brief Wait for a data request from the send tasks. + * + * @return The index of the last completed message and the index of the latest + * processed message. If both are InvalidIdx, it means that all send tasks are in an + * invalid state. + */ + auto wait_for_data_request() -> coro::task> { + size_t last_completed_idx = InvalidIdx; + size_t latest_processed_idx = InvalidIdx; + + auto lock = co_await mtx.scoped_lock(); + co_await request_data.wait(lock, [&] { + auto filtered_view = std::ranges::filter_view(ch_next_idx, [](size_t idx) { + return idx != InvalidIdx; + }); - return latest_processed_idx == state.recv_messages.size(); - }); + auto it = std::ranges::begin(filtered_view); // first valid idx + auto end = std::ranges::end(filtered_view); // end idx - // if both last_completed_idx and latest_processed_idx are invalid, it means that all - // send tasks are in an invalid state. - co_return !(last_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx); -} + if (it == end) { + // no valid indices, so all send tasks are in an invalid state + return true; + } -/** - * @brief Process input messages and notify send tasks to copy & send messages. - * - * @param ctx The context to use. - * @param ch_in The input channel to receive messages from. - * @param state The state of the unbounded fanout. - * @return A coroutine representing the task. 
- */ -Node unbounded_fo_process_input_task( - Context& ctx, - std::shared_ptr ch_in, - std::shared_ptr state -) { - ShutdownAtExit ch_in_shutdown{ch_in}; - StateInputDoneAtExit state_closer{state}; - co_await ctx.executor()->schedule(); - auto& logger = ctx.logger(); + auto [min_it, max_it] = std::minmax_element(it, end); + last_completed_idx = *min_it; + latest_processed_idx = *max_it; - logger.trace("Scheduled process input task"); + return latest_processed_idx == recv_messages.size(); + }); - // index of the first message to purge - size_t purge_idx = 0; + co_return std::make_pair(last_completed_idx, latest_processed_idx); + } - // input_done is only set by this task, so reading without lock is safe here - while (!state->input_done) { - size_t last_completed_idx = InvalidIdx, latest_processed_idx = InvalidIdx; + /** + * @brief Process input messages and notify send tasks to copy & send messages. + * + * @param ctx The context to use. + * @param ch_in The input channel to receive messages from. + * @return A coroutine representing the task. + */ + Node recv_task(Context& ctx, std::shared_ptr ch_in) { + ShutdownAtExit ch_in_shutdown{ch_in}; + SetInputDoneAtExit set_input_done{.fanout = this}; + co_await ctx.executor()->schedule(); + auto& logger = ctx.logger(); + + logger.trace("Scheduled process input task"); + + // index of the first message to purge + size_t purge_idx = 0; + + // input_done is only set by this task, so reading without lock is safe here + while (!input_done) { + auto [last_completed_idx, latest_processed_idx] = + co_await wait_for_data_request(); + if (last_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx) { + break; // all send tasks are in an invalid state, so we need to break + } - if (!co_await wait_for_data_request( - *state, last_completed_idx, latest_processed_idx - )) - { - break; // waiting returned an invalid state, so we need to break - } + // receive a message from the input channel + auto msg = co_await ch_in->receive(); - // receive a message from the input channel - auto msg = co_await ch_in->receive(); + { // relock mtx to update input_done/recv_messages + auto lock = co_await mtx.scoped_lock(); + if (msg.empty()) { + input_done = true; + } else { + recv_messages.emplace_back(std::move(msg)); + } + } - { // relock mtx to update input_done/recv_messages - auto lock = co_await state->mtx.scoped_lock(); - if (msg.empty()) { - state->input_done = true; - } else { - state->recv_messages.emplace_back(std::move(msg)); + // notify send_tasks to copy & send messages + co_await data_ready.notify_all(); + + // purge completed send_tasks. This will reset the messages to empty, so that + // they release the memory, however the deque is not resized. This guarantees + // that the indices are not invalidated. intentionally not locking the mtx + // here, because we only need to know a lower-bound on the last completed idx + // (ch_next_idx values are monotonically increasing) + while (purge_idx + 1 < last_completed_idx) { + recv_messages[purge_idx].reset(); + purge_idx++; } + logger.trace("recv_messages active size: ", recv_messages.size() - purge_idx); } - // notify send_tasks to copy & send messages - co_await state->data_ready.notify_all(); - - // purge completed send_tasks. This will reset the messages to empty, so that - // they release the memory, however the deque is not resized. This guarantees - // that the indices are not invalidated. 
intentionally not locking the mtx - // here, because we only need to know a lower-bound on the last completed idx - // (ch_next_idx values are monotonically increasing) - while (purge_idx + 1 < last_completed_idx) { - state->recv_messages[purge_idx].reset(); - purge_idx++; - } - logger.trace( - "recv_messages active size: ", state->recv_messages.size() - purge_idx - ); + co_await ch_in->drain(ctx.executor()); + logger.trace("Process input task completed"); } - co_await ch_in->drain(ctx.executor()); - logger.trace("Process input task completed"); -} + coro::mutex mtx; ///< notify send tasks to copy & send messages + coro::condition_variable + data_ready; ///< notify send tasks to copy & send messages notify this task to + ///< receive more data from the input channel + coro::condition_variable + request_data; ///< notify recv task to receive more data from the input channel + bool input_done{false}; ///< set to true when the input channel is fully consumed + std::deque + recv_messages; ///< messages received from the input channel. We use a deque to + ///< avoid references being invalidated by reallocations. + std::vector ch_next_idx; ///< next index to send for each channel +}; /** * @brief Broadcast messages from one input channel to multiple output channels. @@ -383,19 +399,17 @@ Node unbounded_fanout( ShutdownAtExit chs_out_shutdown{chs_out}; co_await ctx->executor()->schedule(); auto& logger = ctx->logger(); - auto state = std::make_shared(chs_out.size()); + auto fanout = std::make_unique(chs_out.size()); std::vector tasks; tasks.reserve(chs_out.size() + 1); for (size_t i = 0; i < chs_out.size(); i++) { - tasks.emplace_back(executor.schedule( - unbounded_fo_send_task(*ctx, i, std::move(chs_out[i]), state) - )); + tasks.emplace_back( + executor.schedule(fanout->send_task(*ctx, i, std::move(chs_out[i]))) + ); } - tasks.emplace_back(executor.schedule( - unbounded_fo_process_input_task(*ctx, std::move(ch_in), std::move(state)) - )); + tasks.emplace_back(executor.schedule(fanout->recv_task(*ctx, std::move(ch_in)))); coro_results(co_await coro::when_all(std::move(tasks))); logger.debug("Unbounded fanout completed"); @@ -409,13 +423,11 @@ Node fanout( std::vector> chs_out, FanoutPolicy policy ) { - RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); - - // if there is only one output channel, both bounded and unbounded implementations are - // semantically equivalent. So we can use the bounded fanout implementation. 
- if (chs_out.size() == 1) { - return bounded_fanout(std::move(ctx), std::move(ch_in), std::move(chs_out)); - } + RAPIDSMPF_EXPECTS( + chs_out.size() > 1, + "fanout requires at least 2 output channels", + std::invalid_argument + ); switch (policy) { case FanoutPolicy::BOUNDED: diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index 8f8d2d5e2..fb29f0dd2 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -58,6 +58,18 @@ std::string policy_to_string(FanoutPolicy policy) { } } +using BaseStreamingFanout = BaseStreamingFixture; + +TEST_F(BaseStreamingFanout, InvalidNumberOfOutputChannels) { + auto in = ctx->create_channel(); + std::vector> out_chs; + out_chs.push_back(ctx->create_channel()); + EXPECT_THROW( + std::ignore = node::fanout(ctx, in, out_chs, FanoutPolicy::BOUNDED), + std::invalid_argument + ); +} + class StreamingFanout : public BaseStreamingFixture, public ::testing::WithParamInterface> { @@ -84,7 +96,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), ::testing::Values(1, 4), // number of threads - ::testing::Values(1, 4), // number of output channels + ::testing::Values(2, 4), // number of output channels ::testing::Values(10, 100) // number of messages ), [](testing::TestParamInfo const& info) { @@ -413,7 +425,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::Values(FanoutPolicy::BOUNDED, FanoutPolicy::UNBOUNDED), ::testing::Values(1, 4), // number of threads - ::testing::Values(1, 4), // number of output channels + ::testing::Values(2, 4), // number of output channels ::testing::Values(10, 100) // number of messages ), [](testing::TestParamInfo const& info) { diff --git a/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py index 8c38ff287..8fe54bb9c 100644 --- a/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py +++ b/python/rapidsmpf/rapidsmpf/tests/streaming/test_fanout.py @@ -97,6 +97,11 @@ def test_fanout_multiple_outputs( context.create_channel() for _ in range(num_outputs) ] + if num_outputs == 1: + with pytest.raises(ValueError): + fanout(context, ch_in, chs_out, policy) + return + # Create test messages messages = [] for i in range(3): From 1cfa368bfd8398133bf05de7601fb6b2ec0068c4 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 19 Nov 2025 11:00:43 -0800 Subject: [PATCH 32/43] minor change Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 465b295fb..ba6635f17 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -178,7 +178,6 @@ struct UnboundedFanout { * @param ctx The context to use. * @param self Self index of the task * @param ch_out The output channel to send messages to. - * @param state The state of the unbounded fanout. * @return A coroutine representing the task. 
*/ Node send_task(Context& ctx, size_t self, std::shared_ptr ch_out) { From 9e213dbb72c4b240ce57addea0a9fa8b2880c228 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 19 Nov 2025 11:46:20 -0800 Subject: [PATCH 33/43] API changes Signed-off-by: niranda perera --- cpp/tests/streaming/test_fanout.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tests/streaming/test_fanout.cpp b/cpp/tests/streaming/test_fanout.cpp index fb29f0dd2..74a2420be 100644 --- a/cpp/tests/streaming/test_fanout.cpp +++ b/cpp/tests/streaming/test_fanout.cpp @@ -10,7 +10,7 @@ #include -#include +#include #include #include #include From 0ad1b4480a02c36d8b4706bec904b96270ac80b9 Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Thu, 20 Nov 2025 08:47:13 -0800 Subject: [PATCH 34/43] Update cpp/src/streaming/core/fanout.cpp Co-authored-by: Lawrence Mitchell --- cpp/src/streaming/core/fanout.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index ba6635f17..2a877414e 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -90,6 +90,10 @@ Node bounded_fanout( break; } + std::erase_if(chs_out, [](auto&& ch) { return ch->is_shutdown(); }); + if (chs_out.empty()) { + break; + } co_await send_to_channels(ctx.get(), std::move(msg), chs_out); logger.trace("Sent message ", msg.sequence_number()); } From aed249e5b8c15b60ed1b98025421703c215e56ee Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Thu, 20 Nov 2025 08:48:49 -0800 Subject: [PATCH 35/43] Update cpp/src/streaming/core/fanout.cpp Co-authored-by: Lawrence Mitchell --- cpp/src/streaming/core/fanout.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 2a877414e..b329d62ec 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -288,8 +288,8 @@ struct UnboundedFanout { return idx != InvalidIdx; }); - auto it = std::ranges::begin(filtered_view); // first valid idx - auto end = std::ranges::end(filtered_view); // end idx + auto it = std::ranges::begin(filtered_view); + auto end = std::ranges::end(filtered_view); if (it == end) { // no valid indices, so all send tasks are in an invalid state From 6147d63022f4cb6178a7c215df023abc6411a9b3 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 20 Nov 2025 12:46:36 -0800 Subject: [PATCH 36/43] addressing PR comments Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 204 +++++++++++++++--------------- 1 file changed, 100 insertions(+), 104 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 2a877414e..7a863fc8f 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -44,24 +44,33 @@ constexpr std::span try_memory_types(Message const& msg) { * @param chs_out The set of output channels to which the message is sent. 
*/ Node send_to_channels( - Context* ctx, Message&& msg, std::vector>& chs_out + Context& ctx, Message&& msg, std::vector>& chs_out ) { RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); - std::vector> tasks; - tasks.reserve(chs_out.size()); + auto async_copy_and_send = [](Context& ctx_, + Message const& msg_, + size_t msg_sz_, + Channel& ch_) -> coro::task { + co_await ctx_.executor()->schedule(); + auto res = ctx_.br()->reserve_or_fail(msg_sz_, try_memory_types(msg_)); + co_return co_await ch_.send(msg_.copy(res)); + }; + + // async copy & send tasks for all channels except the last one + std::vector> async_send_tasks; + async_send_tasks.reserve(chs_out.size() - 1); + size_t msg_sz = msg.copy_cost(); for (size_t i = 0; i < chs_out.size() - 1; i++) { - // do a reservation for each copy, so that it will fallback to host memory if - // needed - auto res = ctx->br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); - tasks.emplace_back(chs_out[i]->send(msg.copy(res))); + async_send_tasks.emplace_back(async_copy_and_send(ctx, msg, msg_sz, *chs_out[i])); } - // move the message to the last channel to avoid extra copy - tasks.emplace_back(chs_out.back()->send(std::move(msg))); // note that the send tasks may return false if the channel is shut down. But we can // safely ignore this in bounded fanout. - coro_results(co_await coro::when_all(std::move(tasks))); + coro_results(co_await coro::when_all(std::move(async_send_tasks))); + + // move the message to the last channel to avoid extra copy + co_await chs_out.back()->send(std::move(msg)); } /** @@ -82,7 +91,6 @@ Node bounded_fanout( ) { ShutdownAtExit c1{ch_in}; ShutdownAtExit c2{chs_out}; - auto& logger = ctx->logger(); co_await ctx->executor()->schedule(); while (true) { auto msg = co_await ch_in->receive(); @@ -90,12 +98,13 @@ Node bounded_fanout( break; } + // filter out shut down channels to avoid making unnecessary copies std::erase_if(chs_out, [](auto&& ch) { return ch->is_shutdown(); }); if (chs_out.empty()) { + // all channels are shut down, so we can break & shutdown the input channel break; } - co_await send_to_channels(ctx.get(), std::move(msg), chs_out); - logger.trace("Sent message ", msg.sequence_number()); + co_await send_to_channels(*ctx, std::move(msg), chs_out); } std::vector drain_tasks; @@ -104,7 +113,6 @@ Node bounded_fanout( drain_tasks.emplace_back(ch->drain(ctx->executor())); } coro_results(co_await coro::when_all(std::move(drain_tasks))); - logger.trace("Completed bounded fanout"); } /** @@ -117,22 +125,28 @@ Node bounded_fanout( * Main task operation: * - There is a shared deque of cached messages, and a vector that indicates the next * index of the message to be sent to each output channel. - * - Recv task awaits until the number of cached messages is equal to the latest sent - * message index by any of the send tasks. This notifies the recv task to pull a message - * from the input channel, cache it, and notify the send tasks about the new messages. - * recv task continues this process until the input channel is fully consumed. - * - Each send task awaits until there are more cached messages to send. Once notified, it - * determines the current end of the cached messages, and sends messages in the range - * [next_idx, end_idx). Once these messages have been sent, it updates the next index to - * end_idx and notifies the recv task. + * - All shared resources are protected by a mutex. 
There are two condition variables + * where: + * - recv task notifies send tasks when new messages are cached + * - send tasks notify recv task when they have completed sending messages + * - Recv task awaits until the number of cached messages at least one send task has + * completed sending all the cached messages. It will then pull a message from the input + * channel, cache it, and notify the send tasks about the new messages. recv task + * continues this process until the input channel is fully consumed. + * - Each send task awaits until there are more cached messages to send. When the new + * messages available noitification is received, it will continue to copy and send cached + * messages, starting from the index of the last sent message, to the end of the cached + * messages (as it last observed). Then it updates the last completed message index and + * notifies the recv task. This process continues until the recv task notifies that the + * input channel is fully consumed. * * Additional considerations: - * - In the recv task loop, it also identifies the last completed message index by all - * send tasks. Message upto this index are no longer needed, and are purged from the - * cached messages. + * - In the recv task loop, it also identifies the lowest completed message index by all + * send tasks. Message upto this index are no longer needed, and are released from the + * cached messages deque. * - When a send task fails to send a message, this means the channel may have been - * prematurely shut down. In this case, it sets its index to InvalidIdx. Recv task will - * filter out channels with InvalidIdx. + * prematurely shut down. In this case, it sets a sential value to mark it as invalid. + * Recv task will filter out channels with the invalid sentinel value. * - There two RAII helpers to ensure that the notification mechanisms are properly * cleaned up when the unbounded fanout state goes out of scope/ encounters an error. * @@ -143,7 +157,7 @@ struct UnboundedFanout { * * @param num_channels The number of output channels. */ - explicit UnboundedFanout(size_t num_channels) : ch_next_idx(num_channels, 0) {} + explicit UnboundedFanout(size_t num_channels) : per_ch_processed(num_channels, 0) {} /** * @brief Sentinel value indicating that the index is invalid. This is set when a @@ -158,21 +172,20 @@ struct UnboundedFanout { */ struct SetChannelIdxInvalidAtExit { UnboundedFanout* fanout; - size_t idx; + size_t& self_next_idx; ~SetChannelIdxInvalidAtExit() { coro::sync_wait(set_channel_idx_invalid()); } Node set_channel_idx_invalid() { - if (idx != InvalidIdx) { + if (self_next_idx != InvalidIdx) { { auto lock = co_await fanout->mtx.scoped_lock(); - fanout->ch_next_idx[idx] = InvalidIdx; + self_next_idx = InvalidIdx; } co_await fanout->request_data.notify_one(); } - idx = InvalidIdx; } }; @@ -180,28 +193,28 @@ struct UnboundedFanout { * @brief Send messages to multiple output channels. * * @param ctx The context to use. - * @param self Self index of the task + * @param self_next_idx Next index to send for the current channel * @param ch_out The output channel to send messages to. * @return A coroutine representing the task. 
*/ - Node send_task(Context& ctx, size_t self, std::shared_ptr ch_out) { + Node send_task(Context& ctx, size_t& self_next_idx, std::shared_ptr ch_out) { ShutdownAtExit ch_shutdown{ch_out}; - SetChannelIdxInvalidAtExit set_ch_idx_invalid{.fanout = this, .idx = self}; + SetChannelIdxInvalidAtExit set_ch_idx_invalid{ + .fanout = this, .self_next_idx = self_next_idx + }; co_await ctx.executor()->schedule(); - auto& logger = ctx.logger(); - - size_t curr_recv_msg_sz = 0; // current size of the recv_messages deque + size_t n_available_messages = 0; // number of messages available to send while (true) { { auto lock = co_await mtx.scoped_lock(); co_await data_ready.wait(lock, [&] { - // irrespective of input_done, update the end_idx to the total number - // of messages - curr_recv_msg_sz = recv_messages.size(); - return input_done || ch_next_idx[self] < curr_recv_msg_sz; + // irrespective of no_more_input, update the end_idx to the total + // number of messages + n_available_messages = recv_messages.size(); + return no_more_input || self_next_idx < n_available_messages; }); - if (input_done && ch_next_idx[self] == curr_recv_msg_sz) { + if (no_more_input && self_next_idx == n_available_messages) { // no more messages will be received, and all messages have been sent break; } @@ -210,12 +223,10 @@ struct UnboundedFanout { // now we can copy & send messages in indices [next_idx, curr_recv_msg_sz) // it is guaranteed that message purging will be done only on indices less // than next_idx, so we can safely send messages without locking the mtx - for (size_t i = ch_next_idx[self]; i < curr_recv_msg_sz; i++) { + for (size_t i = self_next_idx; i < n_available_messages; i++) { auto const& msg = recv_messages[i]; RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); - // make reservations for each message so that it will fallback to host - // memory if needed auto res = ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); if (!co_await ch_out->send(msg.copy(res))) { @@ -226,16 +237,13 @@ struct UnboundedFanout { co_return; } } - logger.trace( - "sent ", self, " [", ch_next_idx[self], ", ", curr_recv_msg_sz, ")" - ); - // now next_idx can be updated to end_idx, and if !input_done, we need to + // now next_idx can be updated to end_idx, and if !no_more_input, we need to // request the recv task for more data auto lock = co_await mtx.scoped_lock(); - ch_next_idx[self] = curr_recv_msg_sz; - if (ch_next_idx[self] == recv_messages.size()) { - if (input_done) { + self_next_idx = n_available_messages; + if (self_next_idx == recv_messages.size()) { + if (no_more_input) { // no more messages will be received, and all messages have been sent break; } else { @@ -247,11 +255,10 @@ struct UnboundedFanout { } co_await ch_out->drain(ctx.executor()); - logger.trace("Send task ", self, " completed"); } /** - * @brief RAII helper class to set input_done to true and notify all send tasks to + * @brief RAII helper class to set no_more_input to true and notify all send tasks to * wind down when the unbounded fanout state goes out of scope. 
*/ struct SetInputDoneAtExit { @@ -261,11 +268,11 @@ struct UnboundedFanout { coro::sync_wait(set_input_done()); } - // forcibly set input_done to true and notify all send tasks to wind down + // forcibly set no_more_input to true and notify all send tasks to wind down Node set_input_done() { { auto lock = co_await fanout->mtx.scoped_lock(); - fanout->input_done = true; + fanout->no_more_input = true; } co_await fanout->data_ready.notify_all(); } @@ -279,17 +286,17 @@ struct UnboundedFanout { * invalid state. */ auto wait_for_data_request() -> coro::task> { - size_t last_completed_idx = InvalidIdx; + size_t lowest_completed_idx = InvalidIdx; size_t latest_processed_idx = InvalidIdx; auto lock = co_await mtx.scoped_lock(); co_await request_data.wait(lock, [&] { - auto filtered_view = std::ranges::filter_view(ch_next_idx, [](size_t idx) { - return idx != InvalidIdx; - }); + auto filtered_view = std::ranges::filter_view( + per_ch_processed, [](size_t idx) { return idx != InvalidIdx; } + ); - auto it = std::ranges::begin(filtered_view); // first valid idx - auto end = std::ranges::end(filtered_view); // end idx + auto it = std::ranges::begin(filtered_view); // advance to first valid idx + auto end = std::ranges::end(filtered_view); if (it == end) { // no valid indices, so all send tasks are in an invalid state @@ -297,13 +304,13 @@ struct UnboundedFanout { } auto [min_it, max_it] = std::minmax_element(it, end); - last_completed_idx = *min_it; + lowest_completed_idx = *min_it; latest_processed_idx = *max_it; return latest_processed_idx == recv_messages.size(); }); - co_return std::make_pair(last_completed_idx, latest_processed_idx); + co_return std::make_pair(lowest_completed_idx, latest_processed_idx); } /** @@ -317,28 +324,26 @@ struct UnboundedFanout { ShutdownAtExit ch_in_shutdown{ch_in}; SetInputDoneAtExit set_input_done{.fanout = this}; co_await ctx.executor()->schedule(); - auto& logger = ctx.logger(); - - logger.trace("Scheduled process input task"); // index of the first message to purge size_t purge_idx = 0; - // input_done is only set by this task, so reading without lock is safe here - while (!input_done) { - auto [last_completed_idx, latest_processed_idx] = + // no_more_input is only set by this task, so reading without lock is safe here + while (!no_more_input) { + auto [lowest_completed_idx, latest_processed_idx] = co_await wait_for_data_request(); - if (last_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx) { - break; // all send tasks are in an invalid state, so we need to break + if (lowest_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx) + { + break; } // receive a message from the input channel auto msg = co_await ch_in->receive(); - { // relock mtx to update input_done/recv_messages + { auto lock = co_await mtx.scoped_lock(); if (msg.empty()) { - input_done = true; + no_more_input = true; } else { recv_messages.emplace_back(std::move(msg)); } @@ -347,44 +352,37 @@ struct UnboundedFanout { // notify send_tasks to copy & send messages co_await data_ready.notify_all(); - // purge completed send_tasks. This will reset the messages to empty, so that - // they release the memory, however the deque is not resized. This guarantees - // that the indices are not invalidated. 
intentionally not locking the mtx - // here, because we only need to know a lower-bound on the last completed idx - // (ch_next_idx values are monotonically increasing) - while (purge_idx + 1 < last_completed_idx) { + // Reset messages that are no longer needed, so that they release the memory, + // however the deque is not resized. This guarantees that the indices are not + // invalidated. + while (purge_idx + 1 < lowest_completed_idx) { recv_messages[purge_idx].reset(); purge_idx++; } - logger.trace("recv_messages active size: ", recv_messages.size() - purge_idx); } co_await ch_in->drain(ctx.executor()); - logger.trace("Process input task completed"); } coro::mutex mtx; ///< notify send tasks to copy & send messages coro::condition_variable - data_ready; ///< notify send tasks to copy & send messages notify this task to - ///< receive more data from the input channel - coro::condition_variable - request_data; ///< notify recv task to receive more data from the input channel - bool input_done{false}; ///< set to true when the input channel is fully consumed + data_ready; ///< recv task notifies send tasks to copy & send messages + coro::condition_variable request_data; ///< send tasks notify recv task to pull more + ///< data from the input channel + bool no_more_input{false}; ///< set to true when the input channel is fully consumed std::deque - recv_messages; ///< messages received from the input channel. We use a deque to - ///< avoid references being invalidated by reallocations. - std::vector ch_next_idx; ///< next index to send for each channel + recv_messages; ///< messages received from the input channel. Using a deque to + ///< avoid invalidating references by reallocations. + std::vector + per_ch_processed; ///< number of messages processed for each channel (ie. next + ///< index to send for each channel) }; /** * @brief Broadcast messages from one input channel to multiple output channels. * - * This is an all-purpose implementation that can support consuming messages by the - * channel order or message order. Output channels could be connected to - * single/multiple consumer nodes. A consumer node can decide to consume all messages - * from a single channel before moving to the next channel, or it can consume messages - * from all channels before moving to the next message. When a message has been sent - * to all output channels, it is purged from the internal deque. + * In contrast to `bounded_fanout`, an unbounded fanout supports arbitrary + * consumption orders of the output channels. * * @param ctx The context to use. * @param ch_in The input channel to receive messages from. 
@@ -401,21 +399,19 @@ Node unbounded_fanout( ShutdownAtExit ch_in_shutdown{ch_in}; ShutdownAtExit chs_out_shutdown{chs_out}; co_await ctx->executor()->schedule(); - auto& logger = ctx->logger(); - auto fanout = std::make_unique(chs_out.size()); + UnboundedFanout fanout(chs_out.size()); std::vector tasks; tasks.reserve(chs_out.size() + 1); for (size_t i = 0; i < chs_out.size(); i++) { - tasks.emplace_back( - executor.schedule(fanout->send_task(*ctx, i, std::move(chs_out[i]))) - ); + tasks.emplace_back(executor.schedule( + fanout.send_task(*ctx, fanout.per_ch_processed[i], std::move(chs_out[i])) + )); } - tasks.emplace_back(executor.schedule(fanout->recv_task(*ctx, std::move(ch_in)))); + tasks.emplace_back(executor.schedule(fanout.recv_task(*ctx, std::move(ch_in)))); coro_results(co_await coro::when_all(std::move(tasks))); - logger.debug("Unbounded fanout completed"); } } // namespace From c7740c8e55294a111b5b00d598b03ffb3bae1a65 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 20 Nov 2025 14:14:07 -0800 Subject: [PATCH 37/43] addressing comments Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 36 +++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 7a863fc8f..2c4e27af1 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -48,10 +48,9 @@ Node send_to_channels( ) { RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); - auto async_copy_and_send = [](Context& ctx_, - Message const& msg_, - size_t msg_sz_, - Channel& ch_) -> coro::task { + auto async_copy_and_send = + [](Context& ctx_, Message const& msg_, size_t msg_sz_, Channel& ch_ + ) -> coro::task { co_await ctx_.executor()->schedule(); auto res = ctx_.br()->reserve_or_fail(msg_sz_, try_memory_types(msg_)); co_return co_await ch_.send(msg_.copy(res)); @@ -281,13 +280,14 @@ struct UnboundedFanout { /** * @brief Wait for a data request from the send tasks. * - * @return The index of the last completed message and the index of the latest - * processed message. If both are InvalidIdx, it means that all send tasks are in an - * invalid state. + * @return A minmax pair of `per_ch_processed` values. min is index of the last + * completed message index + 1 and max is the index of the latest processed message + * index + 1. If both are InvalidIdx, it means that all send tasks are in an invalid + * state. 
*/ auto wait_for_data_request() -> coro::task> { - size_t lowest_completed_idx = InvalidIdx; - size_t latest_processed_idx = InvalidIdx; + size_t per_ch_processed_min = InvalidIdx; + size_t per_ch_processed_max = InvalidIdx; auto lock = co_await mtx.scoped_lock(); co_await request_data.wait(lock, [&] { @@ -304,13 +304,13 @@ struct UnboundedFanout { } auto [min_it, max_it] = std::minmax_element(it, end); - lowest_completed_idx = *min_it; - latest_processed_idx = *max_it; + per_ch_processed_min = *min_it; + per_ch_processed_max = *max_it; - return latest_processed_idx == recv_messages.size(); + return per_ch_processed_max == recv_messages.size(); }); - co_return std::make_pair(lowest_completed_idx, latest_processed_idx); + co_return std::make_pair(per_ch_processed_min, per_ch_processed_max); } /** @@ -330,9 +330,9 @@ struct UnboundedFanout { // no_more_input is only set by this task, so reading without lock is safe here while (!no_more_input) { - auto [lowest_completed_idx, latest_processed_idx] = + auto [per_ch_processed_min, per_ch_processed_max] = co_await wait_for_data_request(); - if (lowest_completed_idx == InvalidIdx && latest_processed_idx == InvalidIdx) + if (per_ch_processed_min == InvalidIdx && per_ch_processed_max == InvalidIdx) { break; } @@ -352,10 +352,10 @@ struct UnboundedFanout { // notify send_tasks to copy & send messages co_await data_ready.notify_all(); - // Reset messages that are no longer needed, so that they release the memory, - // however the deque is not resized. This guarantees that the indices are not + // Reset messages that are no longer needed, so that they release the memory. + // However the deque is not resized. This guarantees that the indices are not // invalidated. - while (purge_idx + 1 < lowest_completed_idx) { + while (purge_idx < per_ch_processed_min) { recv_messages[purge_idx].reset(); purge_idx++; } From 0fef6216767b73cfc251827c1cbd48289e130c30 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 Nov 2025 07:12:53 -0800 Subject: [PATCH 38/43] stashing messages using ref wrappers Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 2c4e27af1..5e76d1aee 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -48,9 +48,10 @@ Node send_to_channels( ) { RAPIDSMPF_EXPECTS(!chs_out.empty(), "output channels cannot be empty"); - auto async_copy_and_send = - [](Context& ctx_, Message const& msg_, size_t msg_sz_, Channel& ch_ - ) -> coro::task { + auto async_copy_and_send = [](Context& ctx_, + Message const& msg_, + size_t msg_sz_, + Channel& ch_) -> coro::task { co_await ctx_.executor()->schedule(); auto res = ctx_.br()->reserve_or_fail(msg_sz_, try_memory_types(msg_)); co_return co_await ch_.send(msg_.copy(res)); @@ -204,6 +205,7 @@ struct UnboundedFanout { co_await ctx.executor()->schedule(); size_t n_available_messages = 0; // number of messages available to send + std::vector> messages_to_send; while (true) { { auto lock = co_await mtx.scoped_lock(); @@ -217,18 +219,20 @@ struct UnboundedFanout { // no more messages will be received, and all messages have been sent break; } + // stash msg references under the lock + messages_to_send.reserve(n_available_messages - self_next_idx); + for (size_t i = self_next_idx; i < n_available_messages; i++) { + messages_to_send.emplace_back(recv_messages[i]); + } } - // now we can copy & 
send messages in indices [next_idx, curr_recv_msg_sz) - // it is guaranteed that message purging will be done only on indices less - // than next_idx, so we can safely send messages without locking the mtx - for (size_t i = self_next_idx; i < n_available_messages; i++) { - auto const& msg = recv_messages[i]; - RAPIDSMPF_EXPECTS(!msg.empty(), "message cannot be empty"); + for (auto const& msg : messages_to_send) { + RAPIDSMPF_EXPECTS(!msg.get().empty(), "message cannot be empty"); - auto res = - ctx.br()->reserve_or_fail(msg.copy_cost(), try_memory_types(msg)); - if (!co_await ch_out->send(msg.copy(res))) { + auto res = ctx.br()->reserve_or_fail( + msg.get().copy_cost(), try_memory_types(msg.get()) + ); + if (!co_await ch_out->send(msg.get().copy(res))) { // Failed to send message. Could be that the channel is shut down. // So we need to abort the send task, and notify the process input // task @@ -236,6 +240,7 @@ struct UnboundedFanout { co_return; } } + messages_to_send.clear(); // now next_idx can be updated to end_idx, and if !no_more_input, we need to // request the recv task for more data From 6bb0bd81a7559abb677b51a2fffab42185cba4ec Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Mon, 24 Nov 2025 07:10:17 -0800 Subject: [PATCH 39/43] Update cpp/src/streaming/core/fanout.cpp Co-authored-by: Lawrence Mitchell --- cpp/src/streaming/core/fanout.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 5e76d1aee..676ed1d13 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -204,7 +204,7 @@ struct UnboundedFanout { }; co_await ctx.executor()->schedule(); - size_t n_available_messages = 0; // number of messages available to send + size_t n_available_messages = 0; std::vector> messages_to_send; while (true) { { From 4a66c0fdc760c929f571e708975f059fee469b8b Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 25 Nov 2025 11:39:13 -0800 Subject: [PATCH 40/43] addressing PR comments Signed-off-by: niranda perera --- cpp/src/streaming/core/fanout.cpp | 32 +++++++++++-------- .../rapidsmpf/streaming/core/fanout.pyx | 12 +++---- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 676ed1d13..53bb3e628 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -147,7 +147,7 @@ Node bounded_fanout( * - When a send task fails to send a message, this means the channel may have been * prematurely shut down. In this case, it sets a sential value to mark it as invalid. * Recv task will filter out channels with the invalid sentinel value. - * - There two RAII helpers to ensure that the notification mechanisms are properly + * - There are two RAII helpers to ensure that the notification mechanisms are properly * cleaned up when the unbounded fanout state goes out of scope/ encounters an error. * */ @@ -369,18 +369,24 @@ struct UnboundedFanout { co_await ch_in->drain(ctx.executor()); } - coro::mutex mtx; ///< notify send tasks to copy & send messages - coro::condition_variable - data_ready; ///< recv task notifies send tasks to copy & send messages - coro::condition_variable request_data; ///< send tasks notify recv task to pull more - ///< data from the input channel - bool no_more_input{false}; ///< set to true when the input channel is fully consumed - std::deque - recv_messages; ///< messages received from the input channel. 
Using a deque to - ///< avoid invalidating references by reallocations. - std::vector - per_ch_processed; ///< number of messages processed for each channel (ie. next - ///< index to send for each channel) + coro::mutex mtx; + + /// @brief recv task notifies send tasks to copy & send messages + coro::condition_variable data_ready; + + /// @brief send tasks notify recv task to pull more data from the input channel + coro::condition_variable request_data; + + /// @brief set to true when the input channel is fully consumed + bool no_more_input{false}; + + /// @brief messages received from the input channel. Using a deque to avoid + /// invalidating references by reallocations. + std::deque recv_messages; + + /// @brief number of messages processed for each channel (ie. next index to send for + /// each channel) + std::vector per_ch_processed; }; /** diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx index 159022f8f..ed786e43f 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -11,7 +11,8 @@ from rapidsmpf.streaming.core.fanout cimport FanoutPolicy from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node -def fanout(Context ctx, Channel ch_in, list chs_out, FanoutPolicy policy): +def fanout(Context ctx, Channel ch_in, list[Channel] chs_out, FanoutPolicy policy) \ + -> CppNode: """ Broadcast messages from one input channel to multiple output channels. @@ -24,18 +25,17 @@ def fanout(Context ctx, Channel ch_in, list chs_out, FanoutPolicy policy): Parameters ---------- - ctx : Context + ctx : The node context to use. - ch_in : Channel + ch_in : Input channel from which messages are received. - chs_out : list[Channel] + chs_out : Output channels to which messages are broadcast. - policy : FanoutPolicy + policy : The fanout strategy to use (see FanoutPolicy). Returns ------- - CppNode Streaming node representing the fanout operation. Raises From a315b629124527c4965a0e09c9f9f509f9db979c Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Tue, 25 Nov 2025 12:23:38 -0800 Subject: [PATCH 41/43] Apply suggestions from code review Co-authored-by: Mads R. B. Kristensen --- .../rapidsmpf/rapidsmpf/streaming/core/fanout.pyx | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx index ed786e43f..fca3bb6b3 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -11,8 +11,7 @@ from rapidsmpf.streaming.core.fanout cimport FanoutPolicy from rapidsmpf.streaming.core.node cimport CppNode, cpp_Node -def fanout(Context ctx, Channel ch_in, list[Channel] chs_out, FanoutPolicy policy) \ - -> CppNode: +def fanout(Context ctx, Channel ch_in, chs_out, FanoutPolicy policy): """ Broadcast messages from one input channel to multiple output channels. @@ -25,18 +24,18 @@ def fanout(Context ctx, Channel ch_in, list[Channel] chs_out, FanoutPolicy polic Parameters ---------- - ctx : + ctx The node context to use. - ch_in : + ch_in Input channel from which messages are received. - chs_out : + chs_out Output channels to which messages are broadcast. - policy : + policy The fanout strategy to use (see FanoutPolicy). Returns ------- - Streaming node representing the fanout operation. + Streaming node representing the fanout operation. 
Raises ------ From 79bfc0a2ce92f9d2682ebbc8f419f2a15f2ddf1a Mon Sep 17 00:00:00 2001 From: Niranda Perera Date: Tue, 25 Nov 2025 12:24:43 -0800 Subject: [PATCH 42/43] Apply suggestions from code review Co-authored-by: Mads R. B. Kristensen --- cpp/src/streaming/core/fanout.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/streaming/core/fanout.cpp b/cpp/src/streaming/core/fanout.cpp index 53bb3e628..1b9631daa 100644 --- a/cpp/src/streaming/core/fanout.cpp +++ b/cpp/src/streaming/core/fanout.cpp @@ -257,7 +257,6 @@ struct UnboundedFanout { } } } - co_await ch_out->drain(ctx.executor()); } @@ -302,7 +301,6 @@ struct UnboundedFanout { auto it = std::ranges::begin(filtered_view); // advance to first valid idx auto end = std::ranges::end(filtered_view); - if (it == end) { // no valid indices, so all send tasks are in an invalid state return true; From 8068f2ec6b209ecf8bb9bf7ab801f4d094bf3dbd Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 25 Nov 2025 12:32:59 -0800 Subject: [PATCH 43/43] Addressing PR comments Signed-off-by: niranda perera --- python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx index fca3bb6b3..a5edaef5b 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx +++ b/python/rapidsmpf/rapidsmpf/streaming/core/fanout.pyx @@ -31,8 +31,10 @@ def fanout(Context ctx, Channel ch_in, chs_out, FanoutPolicy policy): chs_out Output channels to which messages are broadcast. policy - The fanout strategy to use (see FanoutPolicy). - + The fanout policy to use. `FanoutPolicy.BOUNDED` can be used if all + output channels are being consumed by independent consumers in the + downstream. `FanoutPolicy.UNBOUNDED` can be used if the output channels + are being consumed by a single/ shared consumer in the downstream. Returns ------- Streaming node representing the fanout operation.
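The policy guidance added in this final patch is easier to see with a small, self-contained analogy. The sketch below is not the rapidsmpf API; it models channels with plain asyncio queues, and every name in it is illustrative. It shows why FanoutPolicy.BOUNDED suits independent per-channel consumers (each message is pushed to every output, under backpressure, before the next one is read), while FanoutPolicy.UNBOUNDED buffers the input so a single, shared consumer can drain one output completely before touching the next.

import asyncio


async def bounded_fanout(ch_in, chs_out):
    # Forward each message to every output before pulling the next one.
    # If the outputs were bounded queues (e.g. maxsize=1), each put would block
    # until that output's consumer made room, so all outputs must be drained
    # concurrently by independent consumers.
    while (msg := await ch_in.get()) is not None:
        for ch in chs_out:
            await ch.put(msg)
    for ch in chs_out:
        await ch.put(None)


async def unbounded_fanout(ch_in, chs_out):
    # Buffer the whole input first, then replay it to each output. Memory use
    # is unbounded, but a single consumer may drain chs_out[0] completely
    # before ever looking at chs_out[1].
    buffered = []
    while (msg := await ch_in.get()) is not None:
        buffered.append(msg)
    for ch in chs_out:
        for msg in buffered:
            await ch.put(msg)
        await ch.put(None)


async def main():
    ch_in = asyncio.Queue()
    chs_out = [asyncio.Queue(), asyncio.Queue()]
    for i in range(3):
        ch_in.put_nowait(i)
    ch_in.put_nowait(None)  # end-of-stream marker

    await unbounded_fanout(ch_in, chs_out)
    # A single, shared consumer drains one output fully, then the next.
    # Feeding the same input through bounded_fanout with maxsize=1 outputs and
    # this consumer would stall on the second message.
    for ch in chs_out:
        while (msg := await ch.get()) is not None:
            print(msg)


asyncio.run(main())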