diff --git a/modules/nats/example_basic.py b/modules/nats/example_basic.py index 9e941bf9b..094b94856 100644 --- a/modules/nats/example_basic.py +++ b/modules/nats/example_basic.py @@ -14,10 +14,10 @@ async def message_handler(msg: Msg): async def basic_example(): - with NatsContainer() as nats_container: + with NatsContainer(jetstream=True) as nats_container: # Get connection parameters host = nats_container.get_container_host_ip() - port = nats_container.get_exposed_port(nats_container.port) + port = nats_container.get_exposed_port(nats_container.client_port) # Create NATS client nc = NATS() @@ -32,7 +32,7 @@ async def basic_example(): print(f"\nCreated stream: {stream.config.name}") # Create consumer - consumer = await js.add_consumer(stream_name="test-stream", durable_name="test-consumer") + consumer = await js.add_consumer(stream="test-stream", durable_name="test-consumer") print(f"Created consumer: {consumer.name}") # Subscribe to subjects diff --git a/modules/nats/testcontainers/nats/__init__.py b/modules/nats/testcontainers/nats/__init__.py index 8ffeca4da..a900036dc 100644 --- a/modules/nats/testcontainers/nats/__init__.py +++ b/modules/nats/testcontainers/nats/__init__.py @@ -47,6 +47,7 @@ def __init__( management_port: int = 8222, expected_ready_log: str = "Server is ready", ready_timeout_secs: int = 120, + jetstream: bool = False, **kwargs, ) -> None: super().__init__(image, **kwargs) @@ -55,6 +56,8 @@ def __init__( self._expected_ready_log = expected_ready_log self._ready_timeout_secs = max(ready_timeout_secs, 0) self.with_exposed_ports(self.client_port, self.management_port) + if jetstream: + self.with_command("-js") @wait_container_is_ready() def _healthcheck(self) -> None: diff --git a/modules/nats/tests/test_nats_jetstream.py b/modules/nats/tests/test_nats_jetstream.py new file mode 100644 index 000000000..368e8c36f --- /dev/null +++ b/modules/nats/tests/test_nats_jetstream.py @@ -0,0 +1,32 @@ +from testcontainers.nats import NatsContainer +from uuid import uuid4 +import pytest + +from nats import connect as nats_connect +from nats.aio.client import Client as NATSClient + + +async def get_client(container: NatsContainer) -> "NATSClient": + """ + Get a nats client. + + Returns: + client: Nats client to connect to the container. + """ + conn_string = container.nats_uri() + client = await nats_connect(conn_string) + return client + + +@pytest.mark.asyncio +async def test_jetstream_add_stream(anyio_backend): + with NatsContainer(jetstream=True) as container: + nc: NATSClient = await get_client(container) + + topic = str(uuid4()) + + js = nc.jetstream() + + await js.add_stream(name="test-stream", subjects=[topic]) + + await nc.close()