diff --git a/docs/postgres.md b/docs/postgres.md index f8fde58..3374065 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -58,4 +58,94 @@ const auto minors = query(conn); - Error handling through `Result` - Resource management through `Ref` - Customizable connection parameters (host, port, database name, etc.) + - LISTEN/NOTIFY for real-time event notifications + +## LISTEN/NOTIFY + +PostgreSQL provides a simple publish-subscribe mechanism through `LISTEN` and `NOTIFY` commands. This allows database clients to receive real-time notifications when events occur, without polling. Any client can send a notification to a channel, and all clients listening on that channel will receive it asynchronously. + +> **Note:** You should use a dedicated connection for LISTEN/NOTIFY, separate from your main database activity and outside any connection pool. This is because the listening connection must remain open and persistent to receive notifications, and connection pools typically recycle connections which would lose the LISTEN state. + +### API + +The `Connection` class provides the following methods for listen/notify: + +```cpp +// Subscribe to a notification channel +rfl::Result listen(const std::string& channel) noexcept; + +// Unsubscribe from a notification channel +rfl::Result unlisten(const std::string& channel) noexcept; + +// Send a notification to a channel with an optional payload +rfl::Result notify(const std::string& channel, + const std::string& payload = "") noexcept; + +// Consume input from the server (must be called before get_notifications) +bool consume_input() noexcept; + +// Retrieve all pending notifications +std::list get_notifications() noexcept; +``` + +The `Notification` struct contains: + +```cpp +struct Notification { + std::string channel; // The channel name + std::string payload; // The notification payload (may be empty) + int backend_pid; // The PID of the notifying backend +}; +``` + +### Subscribing to Channels + +```cpp +auto conn = sqlgen::postgres::connect(creds); +if (!conn) { + // Handle error... + return; +} + +// Subscribe to a channel +auto result = (*conn)->listen("my_channel"); +if (!result) { + // Handle error... +} +``` + +### Receiving Notifications + +To receive notifications, you must periodically call `consume_input()` to read data from the server, then `get_notifications()` to retrieve any pending notifications: + +```cpp +while (running) { + // Consume any available input from the server + if (!(*conn)->consume_input()) { + // Connection error + break; + } + + // Process any pending notifications + auto notifications = (*conn)->get_notifications(); + for (const auto& notification : notifications) { + // Handle the notification + std::cout << "Channel: " << notification.channel + << " Payload: " << notification.payload << std::endl; + } + + // Sleep briefly before checking again + std::this_thread::sleep_for(std::chrono::milliseconds(100)); +} +``` + +### Sending Notifications + +```cpp +// Send a notification with a payload +auto result = (*conn)->notify("my_channel", "event data here"); +if (!result) { + // Handle error... +} +``` diff --git a/include/sqlgen/postgres/Connection.hpp b/include/sqlgen/postgres/Connection.hpp index abaee34..45d6d58 100644 --- a/include/sqlgen/postgres/Connection.hpp +++ b/include/sqlgen/postgres/Connection.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "../Iterator.hpp" #include "../Ref.hpp" @@ -35,6 +36,18 @@ namespace sqlgen::postgres { +enum class NotificationWaitResult { + Ready, // Data available (possibly a NOTIFY) + Timeout, // Timeout elapsed + Error // I/O or connection error +}; + +struct Notification { + std::string channel; + std::string payload; + int backend_pid; +}; + class SQLGEN_API Connection { using Conn = PostgresV2Connection; @@ -86,6 +99,16 @@ class SQLGEN_API Connection { [&](const auto& _data) { return write_impl(_data); }, _begin, _end); } + std::list get_notifications() noexcept; + + rfl::Result listen(const std::string& channel) noexcept; + + rfl::Result unlisten(const std:: string& channel) noexcept; + + rfl::Result notify(const std::string& channel, const std::string& payload = "") noexcept; + + bool consume_input() noexcept; + private: Result insert_impl( const dynamic::Insert& _stmt, @@ -101,6 +124,8 @@ class SQLGEN_API Connection { Result write_impl( const std::vector>>& _data); + bool is_valid_channel_name(const std::string& s) const noexcept; + private: Conn conn_; }; diff --git a/src/sqlgen/postgres/Connection.cpp b/src/sqlgen/postgres/Connection.cpp index 350e6ad..41e2fe1 100644 --- a/src/sqlgen/postgres/Connection.cpp +++ b/src/sqlgen/postgres/Connection.cpp @@ -41,6 +41,69 @@ Result Connection::end_write() { return Nothing{}; } +std::list Connection::get_notifications() noexcept { + std::list notices; + + // Safe to call even if no data — just returns true + if (!PQconsumeInput(conn_.ptr())) { + // Note: In pure wait/consume pattern, this should rarely happen if socket is healthy + // But we don't error here — just skip + return notices; + } + + PGnotify* notify; + while ((notify = PQnotifies(conn_.ptr())) != nullptr) { + notices.push_back({ + .channel = std::string(notify->relname), + .payload = notify->extra[0] ? std::string(notify->extra) : "", + .backend_pid = notify->be_pid + }); + PQfreemem(notify); + } + + return notices; +} + +rfl::Result Connection::listen(const std::string& channel) noexcept { + if (!is_valid_channel_name(channel)) { + return error("Invalid channel name: must be a PostgreSQL identifier"); + } + const std::string sql = "LISTEN " + channel; + return execute(sql); +} + +rfl::Result Connection::unlisten(const std::string& channel) noexcept { + if (channel == "*") { + return execute("UNLISTEN *"); + } + if (!is_valid_channel_name(channel)) { + return error("Invalid channel name"); + } + const std::string sql = "UNLISTEN " + channel; + return execute(sql); +} + +rfl::Result Connection::notify(const std::string& channel, const std::string& payload) noexcept { + if (!is_valid_channel_name(channel)) { + return error("Invalid channel name"); + } + + auto* escaped_payload = PQescapeLiteral(conn_.ptr(), payload.c_str(), payload.size()); + if (!escaped_payload) { + return error("Failed to escape NOTIFY payload"); + } + const std::string sql = "NOTIFY " + channel + ", " + std::string(escaped_payload); + PQfreemem(escaped_payload); + + auto result = execute(sql); + PQflush(conn_.ptr()); + return result; +} + +bool Connection::consume_input() noexcept { + return PQconsumeInput(conn_.ptr()) == 1; +} + Result Connection::insert_impl( const dynamic::Insert& _stmt, const std::vector>>& @@ -160,5 +223,14 @@ Result Connection::write_impl( return Nothing{}; } -} // namespace sqlgen::postgres +bool Connection::is_valid_channel_name(const std::string& s) const noexcept { + if (s.empty()) return false; + const char first = s[0]; + if (first != '_' && !std::isalpha(static_cast(first))) + return false; + return std::all_of(s.begin() + 1, s.end(), [](char c) { + return c == '_' || std::isalnum(static_cast(c)); + }); +} +} // namespace sqlgen::postgres diff --git a/tests/postgres/test_listen_notify.cpp b/tests/postgres/test_listen_notify.cpp new file mode 100644 index 0000000..73644b4 --- /dev/null +++ b/tests/postgres/test_listen_notify.cpp @@ -0,0 +1,202 @@ +#ifndef SQLGEN_BUILD_DRY_TESTS_ONLY + +#include + +#include +#include +#include +#include + +namespace test_listen_notify { + +using namespace sqlgen; +using namespace sqlgen::postgres; + +std::list wait_for_notifications(auto& conn, + std::chrono::milliseconds timeout = std::chrono::milliseconds{2000}) { + const auto deadline = std::chrono::steady_clock::now() + timeout; + std::list all; + + while (std::chrono::steady_clock::now() < deadline) { + auto batch = conn->get_notifications(); + if (!batch.empty()) { + all.splice(all.end(), batch); // efficient for list + // Continue looping briefly in case more arrived + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + } + + return all; +} + +TEST(postgres, basic_listen_notify) { + using namespace sqlgen; + using namespace sqlgen::postgres; + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + auto listener_result = sqlgen::postgres::connect(credentials); + auto sender_result = sqlgen::postgres::connect(credentials); + + ASSERT_TRUE(listener_result); + ASSERT_TRUE(sender_result); + + auto listener = listener_result.value(); + auto sender = sender_result.value(); + + // Listener subscribes to channel + auto listen_res = listener->listen("test_channel"); + ASSERT_TRUE(listen_res) << listen_res.error().what(); + + // Sender sends a notification + auto notify_res = sender->notify("test_channel", "hello world"); + ASSERT_TRUE(notify_res) << notify_res.error().what(); + + // Listener waits and receives + auto notifications = wait_for_notifications(listener); + ASSERT_EQ(notifications.size(), 1); + EXPECT_EQ(notifications.front().channel, "test_channel"); + EXPECT_EQ(notifications.front().payload, "hello world"); + EXPECT_GT(notifications.front().backend_pid, 0); +} + +TEST(postgres, notify_without_listener_is_silent) { + using namespace sqlgen; + using namespace sqlgen::postgres; + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + auto sender_result = sqlgen::postgres::connect(credentials); + + ASSERT_TRUE(sender_result); + + auto sender = sender_result.value(); + + // Notify on a channel with no listener → should not error + auto res = sender->notify("unused_channel", "payload"); + ASSERT_TRUE(res) << res.error().what(); +} + +TEST(postgres, InvalidChannelNameRejected) { + using namespace sqlgen; + using namespace sqlgen::postgres; + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + auto listener_result = sqlgen::postgres::connect(credentials); + auto sender_result = sqlgen::postgres::connect(credentials); + + ASSERT_TRUE(listener_result); + ASSERT_TRUE(sender_result); + + auto listener = listener_result.value(); + auto sender = sender_result.value(); + + auto conn = postgres::connect(credentials); + ASSERT_TRUE(conn); + + // Invalid: starts with digit + EXPECT_FALSE(listener->listen("123chan")); + EXPECT_FALSE(sender->notify("123chan")); + + // Invalid: contains hyphen + EXPECT_FALSE(listener->listen("my-chan")); + EXPECT_FALSE(sender->notify("my-chan")); + + // Valid: underscore + alphanumeric + EXPECT_TRUE(listener->listen("_chan1")); + EXPECT_TRUE(sender->unlisten("_chan1")); +} + +TEST(postgres, unlisten_star) { + using namespace sqlgen; + using namespace sqlgen::postgres; + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + auto listener_result = sqlgen::postgres::connect(credentials); + + ASSERT_TRUE(listener_result); + + auto listener = listener_result.value(); + + ASSERT_TRUE(listener->listen("chan_a")); + ASSERT_TRUE(listener->listen("chan_b")); + + // Unlisten all + ASSERT_TRUE(listener->unlisten("*")); + + // Notify won't be received, but we just verify no error + auto sender = postgres::connect(credentials); + ASSERT_TRUE(listener->notify("chan_a", "test")); +} + +TEST(postgres, multiple_notifications_in_burst) { + using namespace sqlgen; + using namespace sqlgen::postgres; + + const auto credentials = sqlgen::postgres::Credentials{.user = "postgres", + .password = "password", + .host = "localhost", + .dbname = "postgres"}; + + auto listener_result = sqlgen::postgres::connect(credentials); + auto sender_result = sqlgen::postgres::connect(credentials); + + ASSERT_TRUE(listener_result); + ASSERT_TRUE(sender_result); + + auto listener = listener_result.value(); + auto sender = sender_result.value(); + + const std::string channel = "burst_channel"; + ASSERT_TRUE(listener->listen(channel)); + + const int num_notifications = 5; + std::list expected_payloads; + for (int i = 0; i < num_notifications; ++i) { + const std::string payload = "msg_" + std::to_string(i); + expected_payloads.push_back(payload); + ASSERT_TRUE(sender->notify(channel, payload)); + // Small delay to improve reliability on slow CI (optional but safe) + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + // Drain all notifications with retry + std::list notifications; + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(2); + while (notifications.size() < num_notifications && std::chrono::steady_clock::now() < deadline) { + auto batch = wait_for_notifications(listener, std::chrono::milliseconds{100}); + std::move(batch.begin(), batch.end(), std::back_inserter(notifications)); + } + + ASSERT_EQ(notifications.size(), expected_payloads.size()); + + auto expected_it = expected_payloads.begin(); + auto notify_it = notifications.begin(); + int i = 0; + + for (; expected_it != expected_payloads.end(); ++expected_it, ++notify_it, ++i) { + EXPECT_EQ(notify_it->channel, channel) << "Notification #" << i; + EXPECT_EQ(notify_it->payload, *expected_it) << "Notification #" << i; + EXPECT_GT(notify_it->backend_pid, 0) << "Notification #" << i; + } +} + +} + +#endif