From f7f55852921f366edc7aa39df813bc69ec64a54f Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 10:54:09 +0200 Subject: [PATCH 1/9] feat: Add avatar-live and reference image support - Add avatar-live realtime model with set_image() method - Add lucy-restyle-v2v video model with reference_image support - Add AvatarOptions for configuring avatar image on connect - Add VideoRestyleInput with mutual exclusivity validation - Add comprehensive tests for both features --- decart/__init__.py | 6 +- decart/models.py | 54 ++++++- decart/queue/request.py | 2 +- decart/realtime/__init__.py | 3 +- decart/realtime/client.py | 85 ++++++++++- decart/realtime/messages.py | 26 +++- decart/realtime/types.py | 11 +- decart/realtime/webrtc_connection.py | 72 ++++++++- decart/realtime/webrtc_manager.py | 15 +- tests/test_models.py | 13 ++ tests/test_queue.py | 98 ++++++++++++ tests/test_realtime_unit.py | 215 +++++++++++++++++++++++++++ 12 files changed, 581 insertions(+), 19 deletions(-) diff --git a/decart/__init__.py b/decart/__init__.py index 988bc73..452a3dc 100644 --- a/decart/__init__.py +++ b/decart/__init__.py @@ -12,7 +12,7 @@ QueueResultError, TokenCreateError, ) -from .models import models, ModelDefinition +from .models import models, ModelDefinition, VideoRestyleInput from .types import FileInput, ModelState, Prompt from .queue import ( QueueClient, @@ -31,6 +31,7 @@ RealtimeClient, RealtimeConnectOptions, ConnectionState, + AvatarOptions, ) REALTIME_AVAILABLE = True @@ -39,6 +40,7 @@ RealtimeClient = None # type: ignore RealtimeConnectOptions = None # type: ignore ConnectionState = None # type: ignore + AvatarOptions = None # type: ignore __version__ = "0.0.1" @@ -56,6 +58,7 @@ "QueueResultError", "models", "ModelDefinition", + "VideoRestyleInput", "FileInput", "ModelState", "Prompt", @@ -75,5 +78,6 @@ "RealtimeClient", "RealtimeConnectOptions", "ConnectionState", + "AvatarOptions", ] ) diff --git a/decart/models.py b/decart/models.py index fa87f0f..6bb3b9a 100644 --- a/decart/models.py +++ b/decart/models.py @@ -1,10 +1,10 @@ from typing import Literal, Optional, List, Generic, TypeVar -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator from .errors import ModelNotFoundError from .types import FileInput, MotionTrajectoryInput -RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt"] +RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt", "avatar-live"] VideoModels = Literal[ "lucy-dev-i2v", "lucy-fast-v2v", @@ -13,6 +13,7 @@ "lucy-pro-v2v", "lucy-pro-flf2v", "lucy-motion", + "lucy-restyle-v2v", ] ImageModels = Literal["lucy-pro-t2i", "lucy-pro-i2i"] Model = Literal[RealTimeModels, VideoModels, ImageModels] @@ -95,6 +96,38 @@ class ImageToMotionVideoInput(DecartBaseModel): resolution: Optional[str] = None +class VideoRestyleInput(DecartBaseModel): + """Input for lucy-restyle-v2v model. + + Must provide either `prompt` OR `reference_image`, but not both. + `enhance_prompt` is only valid when using `prompt`, not `reference_image`. + """ + + prompt: Optional[str] = Field(default=None, min_length=1, max_length=1000) + reference_image: Optional[FileInput] = None + data: FileInput + seed: Optional[int] = None + resolution: Optional[str] = None + enhance_prompt: Optional[bool] = None + + @model_validator(mode="after") + def validate_prompt_or_reference_image(self) -> "VideoRestyleInput": + has_prompt = self.prompt is not None + has_reference_image = self.reference_image is not None + + if has_prompt == has_reference_image: + raise ValueError( + "Must provide either 'prompt' or 'reference_image', but not both" + ) + + if has_reference_image and self.enhance_prompt is not None: + raise ValueError( + "'enhance_prompt' is only valid when using 'prompt', not 'reference_image'" + ) + + return self + + class TextToImageInput(BaseModel): prompt: str = Field( ..., @@ -144,6 +177,14 @@ class ImageToImageInput(DecartBaseModel): height=704, input_schema=BaseModel, ), + "avatar-live": ModelDefinition( + name="avatar-live", + url_path="/v1/avatar-live/stream", + fps=25, + width=1280, + height=720, + input_schema=BaseModel, + ), }, "video": { "lucy-dev-i2v": ModelDefinition( @@ -202,6 +243,14 @@ class ImageToImageInput(DecartBaseModel): height=704, input_schema=ImageToMotionVideoInput, ), + "lucy-restyle-v2v": ModelDefinition( + name="lucy-restyle-v2v", + url_path="/v1/generate/lucy-restyle-v2v", + fps=25, + width=1280, + height=704, + input_schema=VideoRestyleInput, + ), }, "image": { "lucy-pro-t2i": ModelDefinition( @@ -247,6 +296,7 @@ def video(model: VideoModels) -> VideoModelDefinition: - "lucy-dev-i2v" - Image-to-video (Dev quality) - "lucy-fast-v2v" - Video-to-video (Fast quality) - "lucy-motion" - Image-to-motion-video + - "lucy-restyle-v2v" - Video-to-video with prompt or reference image """ try: return _MODELS["video"][model] # type: ignore[return-value] diff --git a/decart/queue/request.py b/decart/queue/request.py index cca60b6..c242b62 100644 --- a/decart/queue/request.py +++ b/decart/queue/request.py @@ -24,7 +24,7 @@ async def submit_job( for key, value in inputs.items(): if value is not None: - if key in ("data", "start", "end"): + if key in ("data", "start", "end", "reference_image"): content, content_type = await file_input_to_bytes(value, session) form_data.add_field(key, content, content_type=content_type) else: diff --git a/decart/realtime/__init__.py b/decart/realtime/__init__.py index fd681ec..cb20fd2 100644 --- a/decart/realtime/__init__.py +++ b/decart/realtime/__init__.py @@ -1,8 +1,9 @@ from .client import RealtimeClient -from .types import RealtimeConnectOptions, ConnectionState +from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions __all__ = [ "RealtimeClient", "RealtimeConnectOptions", "ConnectionState", + "AvatarOptions", ] diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 37d79b2..bbee9b3 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -1,21 +1,33 @@ from typing import Callable, Optional import asyncio +import base64 import logging import uuid +import aiohttp from aiortc import MediaStreamTrack from .webrtc_manager import WebRTCManager, WebRTCConfiguration -from .messages import PromptMessage +from .messages import PromptMessage, SetAvatarImageMessage from .types import ConnectionState, RealtimeConnectOptions +from ..types import FileInput from ..errors import DecartSDKError, InvalidInputError, WebRTCError +from ..process.request import file_input_to_bytes logger = logging.getLogger(__name__) class RealtimeClient: - def __init__(self, manager: WebRTCManager, session_id: str): + def __init__( + self, + manager: WebRTCManager, + session_id: str, + http_session: Optional[aiohttp.ClientSession] = None, + is_avatar_live: bool = False, + ): self._manager = manager self.session_id = session_id + self._http_session = http_session + self._is_avatar_live = is_avatar_live self._connection_callbacks: list[Callable[[ConnectionState], None]] = [] self._error_callbacks: list[Callable[[DecartSDKError], None]] = [] @@ -24,7 +36,7 @@ async def connect( cls, base_url: str, api_key: str, - local_track: MediaStreamTrack, + local_track: Optional[MediaStreamTrack], options: RealtimeConnectOptions, integration: Optional[str] = None, ) -> "RealtimeClient": @@ -32,6 +44,8 @@ async def connect( ws_url = f"{base_url}{options.model.url_path}" ws_url += f"?api_key={api_key}&model={options.model.name}" + is_avatar_live = options.model.name == "avatar-live" + config = WebRTCConfiguration( webrtc_url=ws_url, api_key=api_key, @@ -43,16 +57,33 @@ async def connect( initial_state=options.initial_state, customize_offer=options.customize_offer, integration=integration, + is_avatar_live=is_avatar_live, ) + # Create HTTP session for file conversions + http_session = aiohttp.ClientSession() + manager = WebRTCManager(config) - client = cls(manager=manager, session_id=session_id) + client = cls( + manager=manager, + session_id=session_id, + http_session=http_session, + is_avatar_live=is_avatar_live, + ) config.on_connection_state_change = client._emit_connection_change config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error)) try: - await manager.connect(local_track) + # For avatar-live, convert and send avatar image before WebRTC connection + avatar_image_base64: Optional[str] = None + if is_avatar_live and options.avatar: + image_bytes, _ = await file_input_to_bytes( + options.avatar.avatar_image, http_session + ) + avatar_image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + await manager.connect(local_track, avatar_image_base64=avatar_image_base64) if options.initial_state: if options.initial_state.prompt: @@ -61,6 +92,7 @@ async def connect( enrich=options.initial_state.prompt.enrich, ) except Exception as e: + await http_session.close() raise WebRTCError(str(e), cause=e) return client @@ -100,6 +132,47 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None: finally: self._manager.unregister_prompt_wait(prompt) + async def set_image(self, image: FileInput) -> None: + """Set or update the avatar image. + + Only available for avatar-live model. + + Args: + image: The image to set. Can be bytes, Path, URL string, or file-like object. + + Raises: + InvalidInputError: If not using avatar-live model or image is invalid. + DecartSDKError: If the server fails to acknowledge the image. + """ + if not self._is_avatar_live: + raise InvalidInputError("set_image() is only available for avatar-live model") + + if not self._http_session: + raise InvalidInputError("HTTP session not available") + + # Convert image to base64 + image_bytes, _ = await file_input_to_bytes(image, self._http_session) + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + event, result = self._manager.register_image_set_wait() + + try: + await self._manager.send_message( + SetAvatarImageMessage(type="set_image", image_data=image_base64) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=15.0) + except asyncio.TimeoutError: + raise DecartSDKError("Image set acknowledgment timed out") + + if not result["success"]: + raise DecartSDKError( + result.get("status") or "Failed to set avatar image" + ) + finally: + self._manager.unregister_image_set_wait() + def is_connected(self) -> bool: return self._manager.is_connected() @@ -108,6 +181,8 @@ def get_connection_state(self) -> ConnectionState: async def disconnect(self) -> None: await self._manager.cleanup() + if self._http_session and not self._http_session.closed: + await self._http_session.close() def on(self, event: str, callback: Callable) -> None: if event == "connection_change": diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index c59c5c6..cb4e97a 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -51,9 +51,22 @@ class PromptAckMessage(BaseModel): error: Optional[str] = None +class ImageSetMessage(BaseModel): + """Acknowledgment for avatar image set from server.""" + + type: Literal["image_set"] + status: str + + # Discriminated union for incoming messages IncomingMessage = Annotated[ - Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage], + Union[ + AnswerMessage, + IceCandidateMessage, + SessionIdMessage, + PromptAckMessage, + ImageSetMessage, + ], Field(discriminator="type"), ] @@ -79,8 +92,17 @@ class PromptMessage(BaseModel): enhance_prompt: bool = True +class SetAvatarImageMessage(BaseModel): + """Set avatar image message.""" + + type: Literal["set_image"] + image_data: str # Base64-encoded image + + # Outgoing message union (no discriminator needed - we know what we're sending) -OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage] +OutgoingMessage = Union[ + OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage +] def parse_incoming_message(data: dict) -> IncomingMessage: diff --git a/decart/realtime/types.py b/decart/realtime/types.py index 7234e24..cde9f59 100644 --- a/decart/realtime/types.py +++ b/decart/realtime/types.py @@ -1,7 +1,7 @@ from typing import Literal, Callable, Optional from dataclasses import dataclass from ..models import ModelDefinition -from ..types import ModelState +from ..types import ModelState, FileInput try: from aiortc import MediaStreamTrack @@ -12,9 +12,18 @@ ConnectionState = Literal["connecting", "connected", "disconnected"] +@dataclass +class AvatarOptions: + """Options for avatar-live model.""" + + avatar_image: FileInput + """The avatar image to use. Can be bytes, Path, URL string, or file-like object.""" + + @dataclass class RealtimeConnectOptions: model: ModelDefinition on_remote_stream: Callable[[MediaStreamTrack], None] initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None + avatar: Optional[AvatarOptions] = None diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index ad087e4..c497221 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -22,6 +22,8 @@ IceCandidateMessage, IceCandidatePayload, PromptAckMessage, + ImageSetMessage, + SetAvatarImageMessage, OutgoingMessage, ) from .types import ConnectionState @@ -48,13 +50,16 @@ def __init__( self._ws_task: Optional[asyncio.Task] = None self._ice_candidates_queue: list[RTCIceCandidate] = [] self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {} + self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None async def connect( self, url: str, - local_track: MediaStreamTrack, + local_track: Optional[MediaStreamTrack], timeout: float = 30, integration: Optional[str] = None, + is_avatar_live: bool = False, + avatar_image_base64: Optional[str] = None, ) -> None: try: await self._set_state("connecting") @@ -71,7 +76,11 @@ async def connect( self._ws_task = asyncio.create_task(self._receive_messages()) - await self._setup_peer_connection(local_track) + # For avatar-live, send avatar image before WebRTC handshake + if is_avatar_live and avatar_image_base64: + await self._send_avatar_image_and_wait(avatar_image_base64) + + await self._setup_peer_connection(local_track, is_avatar_live=is_avatar_live) await self._create_and_send_offer() @@ -90,7 +99,34 @@ async def connect( self._on_error(e) raise WebRTCError(str(e), cause=e) - async def _setup_peer_connection(self, local_track: MediaStreamTrack) -> None: + async def _send_avatar_image_and_wait( + self, image_base64: str, timeout: float = 15.0 + ) -> None: + """Send avatar image and wait for acknowledgment.""" + event, result = self.register_image_set_wait() + + try: + await self._send_message( + SetAvatarImageMessage(type="set_image", image_data=image_base64) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + raise WebRTCError("Avatar image acknowledgment timed out") + + if not result["success"]: + raise WebRTCError( + f"Failed to set avatar image: {result.get('status', 'unknown error')}" + ) + finally: + self.unregister_image_set_wait() + + async def _setup_peer_connection( + self, + local_track: Optional[MediaStreamTrack], + is_avatar_live: bool = False, + ) -> None: config = RTCConfiguration(iceServers=[RTCIceServer(urls=["stun:stun.l.google.com:19302"])]) self._pc = RTCPeerConnection(configuration=config) @@ -128,8 +164,15 @@ async def on_connection_state_change(): async def on_ice_connection_state_change(): logger.debug(f"ICE connection state: {self._pc.iceConnectionState}") - self._pc.addTrack(local_track) - logger.debug("Added local track to peer connection") + # For avatar-live, add recv-only video transceiver + if is_avatar_live: + self._pc.addTransceiver("video", direction="recvonly") + logger.debug("Added video transceiver (recvonly) for avatar-live") + + # Add local audio track if provided + if local_track: + self._pc.addTrack(local_track) + logger.debug("Added local track to peer connection") async def _create_and_send_offer(self) -> None: logger.debug("Creating offer...") @@ -179,6 +222,8 @@ async def _handle_message(self, data: dict) -> None: logger.debug(f"Session ID: {message.session_id}") elif message.type == "prompt_ack": self._handle_prompt_ack(message) + elif message.type == "image_set": + self._handle_image_set(message) async def _handle_answer(self, sdp: str) -> None: logger.debug("Received answer from server") @@ -218,6 +263,23 @@ def _handle_prompt_ack(self, message: PromptAckMessage) -> None: result["error"] = message.error event.set() + def _handle_image_set(self, message: ImageSetMessage) -> None: + logger.debug(f"Received image_set: status={message.status}") + if self._pending_image_set: + event, result = self._pending_image_set + result["status"] = message.status + result["success"] = message.status == "success" + event.set() + + def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: + event = asyncio.Event() + result: dict = {"success": False, "status": None} + self._pending_image_set = (event, result) + return event, result + + def unregister_image_set_wait(self) -> None: + self._pending_image_set = None + def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: event = asyncio.Event() result: dict = {"success": False, "error": None} diff --git a/decart/realtime/webrtc_manager.py b/decart/realtime/webrtc_manager.py index d39a067..b9793ed 100644 --- a/decart/realtime/webrtc_manager.py +++ b/decart/realtime/webrtc_manager.py @@ -31,6 +31,7 @@ class WebRTCConfiguration: initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None integration: Optional[str] = None + is_avatar_live: bool = False def _is_retryable_error(exception: Exception) -> bool: @@ -52,12 +53,18 @@ def __init__(self, configuration: WebRTCConfiguration): before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, ) - async def connect(self, local_track: MediaStreamTrack) -> bool: + async def connect( + self, + local_track: Optional[MediaStreamTrack], + avatar_image_base64: Optional[str] = None, + ) -> bool: try: await self._connection.connect( url=self._config.webrtc_url, local_track=local_track, integration=self._config.integration, + is_avatar_live=self._config.is_avatar_live, + avatar_image_base64=avatar_image_base64, ) return True except Exception as e: @@ -91,3 +98,9 @@ def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: def unregister_prompt_wait(self, prompt: str) -> None: self._connection.unregister_prompt_wait(prompt) + + def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: + return self._connection.register_image_set_wait() + + def unregister_image_set_wait(self) -> None: + self._connection.unregister_image_set_wait() diff --git a/tests/test_models.py b/tests/test_models.py index 702145f..0d0e0ef 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,6 +17,14 @@ def test_realtime_models() -> None: assert model.height == 704 assert model.url_path == "/v1/stream" + # avatar-live model + model = models.realtime("avatar-live") + assert model.name == "avatar-live" + assert model.fps == 25 + assert model.width == 1280 + assert model.height == 720 + assert model.url_path == "/v1/avatar-live/stream" + def test_video_models() -> None: model = models.video("lucy-pro-t2v") @@ -26,6 +34,11 @@ def test_video_models() -> None: model = models.video("lucy-pro-v2v") assert model.name == "lucy-pro-v2v" + # lucy-restyle-v2v model + model = models.video("lucy-restyle-v2v") + assert model.name == "lucy-restyle-v2v" + assert model.url_path == "/v1/generate/lucy-restyle-v2v" + def test_image_models() -> None: model = models.image("lucy-pro-t2i") diff --git a/tests/test_queue.py b/tests/test_queue.py index ad18fc5..3d287ee 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -261,3 +261,101 @@ async def test_queue_includes_user_agent_header() -> None: assert "User-Agent" in headers assert headers["User-Agent"].startswith("decart-python-sdk/") + + +# Tests for lucy-restyle-v2v with reference_image + + +@pytest.mark.asyncio +async def test_queue_restyle_with_prompt() -> None: + """Test lucy-restyle-v2v submission with text prompt.""" + client = DecartClient(api_key="test-key") + + with patch("decart.queue.client.submit_job") as mock_submit: + mock_submit.return_value = MagicMock(job_id="job-789", status="pending") + + job = await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "prompt": "Make it look like anime", + "data": b"fake video data", + "enhance_prompt": True, + } + ) + + assert job.job_id == "job-789" + assert job.status == "pending" + mock_submit.assert_called_once() + + +@pytest.mark.asyncio +async def test_queue_restyle_with_reference_image() -> None: + """Test lucy-restyle-v2v submission with reference image.""" + client = DecartClient(api_key="test-key") + + with patch("decart.queue.client.submit_job") as mock_submit: + mock_submit.return_value = MagicMock(job_id="job-890", status="pending") + + job = await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "reference_image": b"fake image data", + "data": b"fake video data", + } + ) + + assert job.job_id == "job-890" + assert job.status == "pending" + mock_submit.assert_called_once() + + +@pytest.mark.asyncio +async def test_queue_restyle_rejects_both_prompt_and_reference_image() -> None: + """Test that lucy-restyle-v2v rejects both prompt and reference_image.""" + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "prompt": "Make it anime", + "reference_image": b"fake image data", + "data": b"fake video data", + } + ) + + assert "either 'prompt' or 'reference_image'" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_queue_restyle_rejects_neither_prompt_nor_reference_image() -> None: + """Test that lucy-restyle-v2v rejects when neither prompt nor reference_image provided.""" + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "data": b"fake video data", + } + ) + + assert "either 'prompt' or 'reference_image'" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_queue_restyle_rejects_enhance_prompt_with_reference_image() -> None: + """Test that enhance_prompt is only valid with text prompt, not reference_image.""" + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "reference_image": b"fake image data", + "data": b"fake video data", + "enhance_prompt": True, + } + ) + + assert "enhance_prompt" in str(exc_info.value).lower() diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index fee649d..c5b0fef 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -278,3 +278,218 @@ async def set_event(): assert "Server rejected prompt" in str(exc_info.value) mock_manager.unregister_prompt_wait.assert_called_with("New prompt") + + +# Tests for avatar-live model + + +def test_avatar_live_model_available(): + """Test that avatar-live model is available""" + model = models.realtime("avatar-live") + assert model.name == "avatar-live" + assert model.fps == 25 + assert model.width == 1280 + assert model.height == 720 + assert model.url_path == "/v1/avatar-live/stream" + + +@pytest.mark.asyncio +async def test_avatar_live_connect_with_avatar_image(): + """Test avatar-live connection with avatar image option""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.is_connected = MagicMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"fake image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions, AvatarOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + avatar=AvatarOptions(avatar_image=b"fake image bytes"), + ), + ) + + assert realtime_client is not None + assert realtime_client._is_avatar_live is True + mock_file_input.assert_called_once() + # Verify avatar_image_base64 was passed to connect + mock_manager.connect.assert_called_once() + call_kwargs = mock_manager.connect.call_args[1] + assert "avatar_image_base64" in call_kwargs + assert call_kwargs["avatar_image_base64"] is not None + + +@pytest.mark.asyncio +async def test_avatar_live_set_image(): + """Test set_image method for avatar-live""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + image_set_event = asyncio.Event() + image_set_result = {"success": True, "status": "success"} + + mock_manager.register_image_set_wait = MagicMock( + return_value=(image_set_event, image_set_result) + ) + mock_manager.unregister_image_set_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"new image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + ), + ) + + async def set_event(): + await asyncio.sleep(0.01) + image_set_event.set() + + asyncio.create_task(set_event()) + await realtime_client.set_image(b"new avatar image") + + mock_manager.send_message.assert_called() + call_args = mock_manager.send_message.call_args[0][0] + assert call_args.type == "set_image" + assert call_args.image_data is not None + mock_manager.unregister_image_set_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_image_only_for_avatar_live(): + """Test that set_image raises error for non-avatar-live models""" + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + from decart.errors import InvalidInputError + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), # Not avatar-live + on_remote_stream=lambda t: None, + ), + ) + + with pytest.raises(InvalidInputError) as exc_info: + await realtime_client.set_image(b"test image") + + assert "avatar-live" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_avatar_live_set_image_timeout(): + """Test set_image raises on timeout""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + image_set_event = asyncio.Event() + image_set_result = {"success": False, "status": None} + + mock_manager.register_image_set_wait = MagicMock( + return_value=(image_set_event, image_set_result) + ) + mock_manager.unregister_image_set_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + from decart.errors import DecartSDKError + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + ), + ) + + with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError): + with pytest.raises(DecartSDKError) as exc_info: + await realtime_client.set_image(b"test image") + + assert "timed out" in str(exc_info.value).lower() + mock_manager.unregister_image_set_wait.assert_called_once() From 403fcc7970957280edcb2c6aff61f5530feb4d95 Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 10:56:51 +0200 Subject: [PATCH 2/9] black --- decart/models.py | 4 +--- decart/realtime/client.py | 4 +--- decart/realtime/messages.py | 4 +--- decart/realtime/webrtc_connection.py | 4 +--- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/decart/models.py b/decart/models.py index 6bb3b9a..9d7aa58 100644 --- a/decart/models.py +++ b/decart/models.py @@ -116,9 +116,7 @@ def validate_prompt_or_reference_image(self) -> "VideoRestyleInput": has_reference_image = self.reference_image is not None if has_prompt == has_reference_image: - raise ValueError( - "Must provide either 'prompt' or 'reference_image', but not both" - ) + raise ValueError("Must provide either 'prompt' or 'reference_image', but not both") if has_reference_image and self.enhance_prompt is not None: raise ValueError( diff --git a/decart/realtime/client.py b/decart/realtime/client.py index bbee9b3..5266a87 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -167,9 +167,7 @@ async def set_image(self, image: FileInput) -> None: raise DecartSDKError("Image set acknowledgment timed out") if not result["success"]: - raise DecartSDKError( - result.get("status") or "Failed to set avatar image" - ) + raise DecartSDKError(result.get("status") or "Failed to set avatar image") finally: self._manager.unregister_image_set_wait() diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index cb4e97a..2b413df 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -100,9 +100,7 @@ class SetAvatarImageMessage(BaseModel): # Outgoing message union (no discriminator needed - we know what we're sending) -OutgoingMessage = Union[ - OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage -] +OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage] def parse_incoming_message(data: dict) -> IncomingMessage: diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index c497221..3d62f66 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -99,9 +99,7 @@ async def connect( self._on_error(e) raise WebRTCError(str(e), cause=e) - async def _send_avatar_image_and_wait( - self, image_base64: str, timeout: float = 15.0 - ) -> None: + async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = 15.0) -> None: """Send avatar image and wait for acknowledgment.""" event, result = self.register_image_set_wait() From 1d4e9134d88ca9121aeb6e39e617d6dc99515399 Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 10:58:31 +0200 Subject: [PATCH 3/9] remove unused import --- tests/test_realtime_unit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index c5b0fef..50f9cb3 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -296,7 +296,6 @@ def test_avatar_live_model_available(): @pytest.mark.asyncio async def test_avatar_live_connect_with_avatar_image(): """Test avatar-live connection with avatar image option""" - import asyncio client = DecartClient(api_key="test-key") From cd2176b591acb3c54ffa665d18584c8c2256ac20 Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 11:40:09 +0200 Subject: [PATCH 4/9] add examples --- examples/avatar_live.py | 179 ++++++++++++++++++++++++++++++++++++++ examples/video_restyle.py | 152 ++++++++++++++++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 examples/avatar_live.py create mode 100644 examples/video_restyle.py diff --git a/examples/avatar_live.py b/examples/avatar_live.py new file mode 100644 index 0000000..4e5f305 --- /dev/null +++ b/examples/avatar_live.py @@ -0,0 +1,179 @@ +""" +Avatar Live Example + +This example demonstrates how to use the avatar-live model to animate an avatar image. +The avatar can be animated with audio input (microphone or audio file). + +Usage: + # With audio file: + DECART_API_KEY=your-key python avatar_live.py avatar.png audio.mp3 + + # With just avatar image (will wait for audio): + DECART_API_KEY=your-key python avatar_live.py avatar.png + +Requirements: + pip install decart[realtime] +""" + +import asyncio +import os +import sys +from pathlib import Path + +try: + from aiortc.contrib.media import MediaPlayer, MediaRecorder +except ImportError: + print("aiortc is required for this example.") + print("Install with: pip install decart[realtime]") + sys.exit(1) + +from decart import DecartClient, models + + +async def main(): + api_key = os.getenv("DECART_API_KEY") + if not api_key: + print("Error: DECART_API_KEY environment variable not set") + print("Usage: DECART_API_KEY=your-key python avatar_live.py [audio_file]") + return + + if len(sys.argv) < 2: + print("Usage: python avatar_live.py [audio_file]") + print("") + print("Arguments:") + print(" avatar_image - Path to avatar image (PNG, JPG)") + print(" audio_file - Optional: Path to audio file for the avatar to speak") + print("") + print("Examples:") + print(" python avatar_live.py avatar.png") + print(" python avatar_live.py avatar.png speech.mp3") + return + + avatar_image = sys.argv[1] + if not os.path.exists(avatar_image): + print(f"Error: Avatar image not found: {avatar_image}") + return + + audio_file = sys.argv[2] if len(sys.argv) > 2 else None + if audio_file and not os.path.exists(audio_file): + print(f"Error: Audio file not found: {audio_file}") + return + + print(f"šŸ–¼ļø Avatar image: {avatar_image}") + if audio_file: + print(f"šŸ”Š Audio file: {audio_file}") + + # Load audio if provided + audio_track = None + if audio_file: + print("Loading audio file...") + player = MediaPlayer(audio_file) + if player.audio: + audio_track = player.audio + print("āœ“ Audio loaded") + else: + print("āš ļø Warning: No audio stream found in file, continuing without audio") + + try: + from decart.realtime.client import RealtimeClient + from decart.realtime.types import RealtimeConnectOptions, AvatarOptions + except ImportError: + print("Error: Realtime API not available") + print("Install with: pip install decart[realtime]") + return + + print("\nCreating Decart client...") + async with DecartClient(api_key=api_key) as client: + model = models.realtime("avatar-live") + print(f"Using model: {model.name}") + + frame_count = 0 + recorder = None + output_file = Path(f"output_avatar_live.mp4") + + def on_remote_stream(track): + nonlocal frame_count, recorder + frame_count += 1 + if frame_count % 25 == 0: + print(f"šŸ“¹ Received {frame_count} frames...") + + if recorder is None: + print(f"šŸ’¾ Recording to {output_file}") + recorder = MediaRecorder(str(output_file)) + recorder.addTrack(track) + asyncio.create_task(recorder.start()) + + def on_connection_change(state): + print(f"šŸ”„ Connection state: {state}") + + def on_error(error): + print(f"āŒ Error: {error.__class__.__name__} - {error.message}") + + print("\nConnecting to Avatar Live API...") + print("(Sending avatar image...)") + + try: + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=audio_track, # Can be None if no audio + options=RealtimeConnectOptions( + model=model, + on_remote_stream=on_remote_stream, + avatar=AvatarOptions(avatar_image=Path(avatar_image)), + ), + ) + + realtime_client.on("connection_change", on_connection_change) + realtime_client.on("error", on_error) + + print("āœ“ Connected!") + print(f"Session ID: {realtime_client.session_id}") + + if audio_file: + print("\nPlaying audio through avatar...") + print("(The avatar will animate based on the audio)") + else: + print("\nNo audio provided - avatar will be static") + print("You can update the avatar image dynamically using set_image()") + + print("\nPress Ctrl+C to stop and save the recording...") + + # Demo: Update avatar image after 5 seconds (if you want to test set_image) + # Uncomment the following to test dynamic image updates: + # await asyncio.sleep(5) + # print("Updating avatar image...") + # await realtime_client.set_image(Path("new_avatar.png")) + # print("āœ“ Avatar image updated!") + + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print(f"\n\nāœ“ Received {frame_count} frames total") + + finally: + if recorder: + try: + print(f"šŸ’¾ Saving video to {output_file}...") + await asyncio.sleep(0.5) + await recorder.stop() + print(f"āœ“ Video saved to {output_file}") + except Exception as e: + print(f"āš ļø Warning: Could not save video cleanly: {e}") + print(" Video file may be incomplete") + + except Exception as e: + print(f"\nāŒ Connection failed: {e}") + import traceback + traceback.print_exc() + + finally: + if "realtime_client" in locals(): + print("\nDisconnecting...") + await realtime_client.disconnect() + print("āœ“ Disconnected") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/video_restyle.py b/examples/video_restyle.py new file mode 100644 index 0000000..4cef0a6 --- /dev/null +++ b/examples/video_restyle.py @@ -0,0 +1,152 @@ +""" +Video Restyle Example + +This example demonstrates how to use the lucy-restyle-v2v model to restyle a video +using either a text prompt OR a reference image. + +Usage: + # With text prompt: + DECART_API_KEY=your-key python video_restyle.py input.mp4 --prompt "anime style" + + # With reference image: + DECART_API_KEY=your-key python video_restyle.py input.mp4 --reference style.png + +Requirements: + pip install decart +""" + +import asyncio +import argparse +import os +import sys +from pathlib import Path + +from decart import DecartClient, models + + +async def main(): + parser = argparse.ArgumentParser( + description="Restyle a video using text prompt or reference image" + ) + parser.add_argument("video", help="Path to input video file") + parser.add_argument( + "--prompt", "-p", + help="Text prompt describing the style (e.g., 'anime style', 'oil painting')" + ) + parser.add_argument( + "--reference", "-r", + help="Path to reference image for style transfer" + ) + parser.add_argument( + "--output", "-o", + help="Output file path (default: output_restyle.mp4)" + ) + parser.add_argument( + "--seed", "-s", + type=int, + help="Random seed for reproducibility" + ) + parser.add_argument( + "--enhance", + action="store_true", + default=True, + help="Enhance the prompt (only with --prompt, default: True)" + ) + parser.add_argument( + "--no-enhance", + action="store_true", + help="Disable prompt enhancement" + ) + + args = parser.parse_args() + + # Validate arguments + if not args.prompt and not args.reference: + print("Error: Must provide either --prompt or --reference") + parser.print_help() + sys.exit(1) + + if args.prompt and args.reference: + print("Error: Cannot use both --prompt and --reference") + print(" Please choose one or the other") + sys.exit(1) + + api_key = os.getenv("DECART_API_KEY") + if not api_key: + print("Error: DECART_API_KEY environment variable not set") + sys.exit(1) + + video_path = Path(args.video) + if not video_path.exists(): + print(f"Error: Video file not found: {video_path}") + sys.exit(1) + + if args.reference: + ref_path = Path(args.reference) + if not ref_path.exists(): + print(f"Error: Reference image not found: {ref_path}") + sys.exit(1) + + output_path = args.output or f"output_restyle_{video_path.stem}.mp4" + + print("=" * 50) + print("Video Restyle") + print("=" * 50) + print(f"Input video: {video_path}") + if args.prompt: + print(f"Style: Text prompt - '{args.prompt}'") + print(f"Enhance prompt: {not args.no_enhance}") + else: + print(f"Style: Reference image - {args.reference}") + print(f"Output: {output_path}") + if args.seed: + print(f"Seed: {args.seed}") + print("=" * 50) + + async with DecartClient(api_key=api_key) as client: + # Build options + options = { + "model": models.video("lucy-restyle-v2v"), + "data": video_path, + } + + if args.prompt: + options["prompt"] = args.prompt + options["enhance_prompt"] = not args.no_enhance + else: + options["reference_image"] = Path(args.reference) + + if args.seed: + options["seed"] = args.seed + + def on_status_change(job): + status_emoji = { + "pending": "ā³", + "processing": "šŸ”„", + "completed": "āœ…", + "failed": "āŒ", + } + emoji = status_emoji.get(job.status, "•") + print(f"{emoji} Status: {job.status}") + + options["on_status_change"] = on_status_change + + print("\nSubmitting job...") + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + print(f"\nāŒ Job failed: {result.error}") + sys.exit(1) + + print(f"\nāœ… Job completed!") + print(f"šŸ’¾ Saving to {output_path}...") + + with open(output_path, "wb") as f: + f.write(result.data) + + print(f"āœ“ Video saved to {output_path}") + print(f" Size: {len(result.data) / 1024 / 1024:.2f} MB") + + +if __name__ == "__main__": + asyncio.run(main()) From 71315cd6376d7da339b14bc7944b86bc070e7ec9 Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 13:42:41 +0200 Subject: [PATCH 5/9] black --- examples/avatar_live.py | 1 + examples/video_restyle.py | 29 ++++++++--------------------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/examples/avatar_live.py b/examples/avatar_live.py index 4e5f305..6f9b8f9 100644 --- a/examples/avatar_live.py +++ b/examples/avatar_live.py @@ -166,6 +166,7 @@ def on_error(error): except Exception as e: print(f"\nāŒ Connection failed: {e}") import traceback + traceback.print_exc() finally: diff --git a/examples/video_restyle.py b/examples/video_restyle.py index 4cef0a6..04af4b4 100644 --- a/examples/video_restyle.py +++ b/examples/video_restyle.py @@ -30,33 +30,20 @@ async def main(): ) parser.add_argument("video", help="Path to input video file") parser.add_argument( - "--prompt", "-p", - help="Text prompt describing the style (e.g., 'anime style', 'oil painting')" - ) - parser.add_argument( - "--reference", "-r", - help="Path to reference image for style transfer" - ) - parser.add_argument( - "--output", "-o", - help="Output file path (default: output_restyle.mp4)" - ) - parser.add_argument( - "--seed", "-s", - type=int, - help="Random seed for reproducibility" + "--prompt", + "-p", + help="Text prompt describing the style (e.g., 'anime style', 'oil painting')", ) + parser.add_argument("--reference", "-r", help="Path to reference image for style transfer") + parser.add_argument("--output", "-o", help="Output file path (default: output_restyle.mp4)") + parser.add_argument("--seed", "-s", type=int, help="Random seed for reproducibility") parser.add_argument( "--enhance", action="store_true", default=True, - help="Enhance the prompt (only with --prompt, default: True)" - ) - parser.add_argument( - "--no-enhance", - action="store_true", - help="Disable prompt enhancement" + help="Enhance the prompt (only with --prompt, default: True)", ) + parser.add_argument("--no-enhance", action="store_true", help="Disable prompt enhancement") args = parser.parse_args() From 95849422f50d1fb9d5553c9cbbce1df9fdf15c9c Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 13:45:26 +0200 Subject: [PATCH 6/9] ruff --- examples/avatar_live.py | 2 +- examples/video_restyle.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/avatar_live.py b/examples/avatar_live.py index 6f9b8f9..bebb911 100644 --- a/examples/avatar_live.py +++ b/examples/avatar_live.py @@ -89,7 +89,7 @@ async def main(): frame_count = 0 recorder = None - output_file = Path(f"output_avatar_live.mp4") + output_file = Path("output_avatar_live.mp4") def on_remote_stream(track): nonlocal frame_count, recorder diff --git a/examples/video_restyle.py b/examples/video_restyle.py index 04af4b4..dd639d1 100644 --- a/examples/video_restyle.py +++ b/examples/video_restyle.py @@ -125,7 +125,7 @@ def on_status_change(job): print(f"\nāŒ Job failed: {result.error}") sys.exit(1) - print(f"\nāœ… Job completed!") + print("\nāœ… Job completed!") print(f"šŸ’¾ Saving to {output_path}...") with open(output_path, "wb") as f: From 1dcb385703c7cad6b55e3c15e78715b71ebba08c Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Wed, 31 Dec 2025 14:08:56 +0200 Subject: [PATCH 7/9] feat: Add interactive test UI for Decart SDK - Introduced a Gradio-based test UI for easy feature testing. - Added functionality for image and video generation, transformation, and restyling. - Included a tokens API for creating short-lived client tokens. - Updated README with instructions for using the test UI. --- README.md | 22 ++ test_ui.py | 612 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 634 insertions(+) create mode 100644 test_ui.py diff --git a/README.md b/README.md index b292ad6..dbfe4e8 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,28 @@ uv run python examples/realtime_synthetic.py uv lock --upgrade ``` +### Test UI + +The SDK includes an interactive test UI built with Gradio for quickly testing all SDK features without writing code. + +```bash +# Install Gradio +pip install gradio + +# Run the test UI +python test_ui.py +``` + +Then open http://localhost:7860 in your browser. + +The UI provides tabs for: +- **Image Generation** - Text-to-image and image-to-image transformations +- **Video Generation** - Text-to-video, image-to-video, and video-to-video +- **Video Restyle** - Restyle videos using text prompts or reference images +- **Tokens** - Create short-lived client tokens + +Enter your API key at the top of the interface to start testing. + ### Publishing a New Version The package is automatically published to PyPI when you create a GitHub release. diff --git a/test_ui.py b/test_ui.py new file mode 100644 index 0000000..89cf319 --- /dev/null +++ b/test_ui.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +""" +Decart SDK Test UI - Interactive testing interface for the Python SDK. + +Usage: + pip install gradio + python test_ui.py + +Then open http://localhost:7860 in your browser. +""" + +import asyncio +import gradio as gr +from pathlib import Path +from typing import Optional +import tempfile +import os + +# Import the SDK +from decart import DecartClient, models + + +def get_client(api_key: str) -> DecartClient: + """Create a Decart client with the given API key.""" + if not api_key or not api_key.strip(): + raise ValueError("Please enter an API key") + return DecartClient(api_key=api_key.strip()) + + +# ============================================================================ +# Image Processing (Process API) +# ============================================================================ + + +async def process_text_to_image( + api_key: str, + prompt: str, + seed: Optional[int], + resolution: str, + orientation: str, +) -> tuple[Optional[bytes], str]: + """Generate an image from text prompt.""" + try: + client = get_client(api_key) + + options = { + "model": models.image("lucy-pro-t2i"), + "prompt": prompt, + } + if seed: + options["seed"] = seed + if resolution and resolution != "default": + options["resolution"] = resolution + if orientation and orientation != "default": + options["orientation"] = orientation + + result = await client.process(options) + return result, f"Success! Generated image from prompt: '{prompt[:50]}...'" + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_image_to_image( + api_key: str, + prompt: str, + input_image: str, + seed: Optional[int], + strength: float, +) -> tuple[Optional[bytes], str]: + """Transform an image with a prompt.""" + try: + if not input_image: + return None, "Please upload an image" + + client = get_client(api_key) + + options = { + "model": models.image("lucy-pro-i2i"), + "prompt": prompt, + "data": Path(input_image), + } + if seed: + options["seed"] = seed + if strength: + options["strength"] = strength + + result = await client.process(options) + return result, f"Success! Transformed image with prompt: '{prompt[:50]}...'" + except Exception as e: + return None, f"Error: {str(e)}" + + +# ============================================================================ +# Video Processing (Queue API) +# ============================================================================ + + +async def process_video_t2v( + api_key: str, + prompt: str, + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Generate a video from text prompt.""" + try: + client = get_client(api_key) + + options = { + "model": models.video("lucy-pro-t2v"), + "prompt": prompt, + } + if seed: + options["seed"] = seed + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + # Save to temp file + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + return f.name, f"Success! Generated video from prompt: '{prompt[:50]}...'" + + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_video_v2v( + api_key: str, + prompt: str, + input_video: str, + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Transform a video with a prompt.""" + try: + if not input_video: + return None, "Please upload a video" + + client = get_client(api_key) + + options = { + "model": models.video("lucy-pro-v2v"), + "prompt": prompt, + "data": Path(input_video), + } + if seed: + options["seed"] = seed + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + return f.name, f"Success! Transformed video with prompt: '{prompt[:50]}...'" + + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_video_restyle( + api_key: str, + input_video: str, + use_reference_image: bool, + prompt: str, + reference_image: Optional[str], + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Restyle a video with prompt OR reference image.""" + try: + if not input_video: + return None, "Please upload a video" + + client = get_client(api_key) + + options = { + "model": models.video("lucy-restyle-v2v"), + "data": Path(input_video), + } + + if use_reference_image: + if not reference_image: + return None, "Please upload a reference image" + options["reference_image"] = Path(reference_image) + else: + if not prompt: + return None, "Please enter a prompt" + options["prompt"] = prompt + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + if seed: + options["seed"] = seed + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + mode = "reference image" if use_reference_image else f"prompt: '{prompt[:30]}...'" + return f.name, f"Success! Restyled video with {mode}" + + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_video_i2v( + api_key: str, + prompt: str, + input_image: str, + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Generate a video from an image.""" + try: + if not input_image: + return None, "Please upload an image" + + client = get_client(api_key) + + options = { + "model": models.video("lucy-pro-i2v"), + "prompt": prompt, + "data": Path(input_image), + } + if seed: + options["seed"] = seed + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + return f.name, f"Success! Generated video from image" + + except Exception as e: + return None, f"Error: {str(e)}" + + +# ============================================================================ +# Tokens API +# ============================================================================ + + +async def create_token(api_key: str) -> str: + """Create a short-lived client token.""" + try: + client = get_client(api_key) + result = await client.tokens.create() + return f"Success!\n\nToken: {result.api_key}\nExpires: {result.expires_at}" + except Exception as e: + return f"Error: {str(e)}" + + +# ============================================================================ +# Gradio UI +# ============================================================================ + + +def create_ui(): + """Create the Gradio interface.""" + + with gr.Blocks( + title="Decart SDK Test UI", + theme=gr.themes.Soft(), + css=""" + .status-success { color: green; } + .status-error { color: red; } + """, + ) as demo: + gr.Markdown( + """ + # Decart SDK Test UI + + Interactive testing interface for the Decart Python SDK. + Enter your API key below to get started. + """ + ) + + # API Key input (shared across all tabs) + api_key = gr.Textbox( + label="API Key", + placeholder="Enter your Decart API key", + type="password", + elem_id="api-key-input", + ) + + with gr.Tabs(): + # ================================================================ + # Image Processing Tab + # ================================================================ + with gr.TabItem("Image Generation"): + gr.Markdown("### Text to Image") + with gr.Row(): + with gr.Column(): + t2i_prompt = gr.Textbox( + label="Prompt", + placeholder="A beautiful sunset over mountains", + lines=3, + ) + with gr.Row(): + t2i_seed = gr.Number(label="Seed (optional)", precision=0) + t2i_resolution = gr.Dropdown( + label="Resolution", + choices=["default", "720p", "1080p"], + value="default", + ) + t2i_orientation = gr.Dropdown( + label="Orientation", + choices=["default", "landscape", "portrait", "square"], + value="default", + ) + t2i_btn = gr.Button("Generate Image", variant="primary") + with gr.Column(): + t2i_output = gr.Image(label="Generated Image", type="filepath") + t2i_status = gr.Textbox(label="Status", interactive=False) + + t2i_btn.click( + fn=lambda *args: asyncio.run(process_text_to_image(*args)), + inputs=[api_key, t2i_prompt, t2i_seed, t2i_resolution, t2i_orientation], + outputs=[t2i_output, t2i_status], + ) + + gr.Markdown("---") + gr.Markdown("### Image to Image") + with gr.Row(): + with gr.Column(): + i2i_input = gr.Image(label="Input Image", type="filepath") + i2i_prompt = gr.Textbox( + label="Prompt", + placeholder="Make it look like anime", + lines=2, + ) + with gr.Row(): + i2i_seed = gr.Number(label="Seed (optional)", precision=0) + i2i_strength = gr.Slider( + label="Strength", + minimum=0.0, + maximum=1.0, + value=0.75, + step=0.05, + ) + i2i_btn = gr.Button("Transform Image", variant="primary") + with gr.Column(): + i2i_output = gr.Image(label="Transformed Image", type="filepath") + i2i_status = gr.Textbox(label="Status", interactive=False) + + i2i_btn.click( + fn=lambda *args: asyncio.run(process_image_to_image(*args)), + inputs=[api_key, i2i_prompt, i2i_input, i2i_seed, i2i_strength], + outputs=[i2i_output, i2i_status], + ) + + # ================================================================ + # Video Processing Tab + # ================================================================ + with gr.TabItem("Video Generation"): + gr.Markdown("### Text to Video") + with gr.Row(): + with gr.Column(): + t2v_prompt = gr.Textbox( + label="Prompt", + placeholder="A cat walking in a park", + lines=3, + ) + with gr.Row(): + t2v_seed = gr.Number(label="Seed (optional)", precision=0) + t2v_enhance = gr.Checkbox(label="Enhance Prompt", value=True) + t2v_btn = gr.Button("Generate Video", variant="primary") + with gr.Column(): + t2v_output = gr.Video(label="Generated Video") + t2v_status = gr.Textbox(label="Status", interactive=False) + + t2v_btn.click( + fn=lambda *args: asyncio.run(process_video_t2v(*args)), + inputs=[api_key, t2v_prompt, t2v_seed, t2v_enhance], + outputs=[t2v_output, t2v_status], + ) + + gr.Markdown("---") + gr.Markdown("### Image to Video") + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + i2v_prompt = gr.Textbox( + label="Prompt", + placeholder="The scene comes to life", + lines=2, + ) + with gr.Row(): + i2v_seed = gr.Number(label="Seed (optional)", precision=0) + i2v_enhance = gr.Checkbox(label="Enhance Prompt", value=True) + i2v_btn = gr.Button("Generate Video", variant="primary") + with gr.Column(): + i2v_output = gr.Video(label="Generated Video") + i2v_status = gr.Textbox(label="Status", interactive=False) + + i2v_btn.click( + fn=lambda *args: asyncio.run(process_video_i2v(*args)), + inputs=[api_key, i2v_prompt, i2v_input, i2v_seed, i2v_enhance], + outputs=[i2v_output, i2v_status], + ) + + gr.Markdown("---") + gr.Markdown("### Video to Video") + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video") + v2v_prompt = gr.Textbox( + label="Prompt", + placeholder="Make it look like Lego world", + lines=2, + ) + with gr.Row(): + v2v_seed = gr.Number(label="Seed (optional)", precision=0) + v2v_enhance = gr.Checkbox(label="Enhance Prompt", value=True) + v2v_btn = gr.Button("Transform Video", variant="primary") + with gr.Column(): + v2v_output = gr.Video(label="Transformed Video") + v2v_status = gr.Textbox(label="Status", interactive=False) + + v2v_btn.click( + fn=lambda *args: asyncio.run(process_video_v2v(*args)), + inputs=[api_key, v2v_prompt, v2v_input, v2v_seed, v2v_enhance], + outputs=[v2v_output, v2v_status], + ) + + # ================================================================ + # Video Restyle Tab (NEW - with reference image support) + # ================================================================ + with gr.TabItem("Video Restyle (NEW)"): + gr.Markdown( + """ + ### Video Restyle with Prompt OR Reference Image + + This model supports two modes: + - **Text Prompt**: Describe the style you want + - **Reference Image**: Upload an image to use as style reference + """ + ) + + with gr.Row(): + with gr.Column(): + restyle_input = gr.Video(label="Input Video") + restyle_mode = gr.Checkbox( + label="Use Reference Image (instead of text prompt)", + value=False, + ) + restyle_prompt = gr.Textbox( + label="Prompt", + placeholder="Make it look like anime", + lines=2, + visible=True, + ) + restyle_ref_image = gr.Image( + label="Reference Image", + type="filepath", + visible=False, + ) + with gr.Row(): + restyle_seed = gr.Number(label="Seed (optional)", precision=0) + restyle_enhance = gr.Checkbox( + label="Enhance Prompt", + value=True, + visible=True, + ) + restyle_btn = gr.Button("Restyle Video", variant="primary") + with gr.Column(): + restyle_output = gr.Video(label="Restyled Video") + restyle_status = gr.Textbox(label="Status", interactive=False) + + # Toggle visibility based on mode + def toggle_mode(use_ref): + return ( + gr.update(visible=not use_ref), # prompt + gr.update(visible=use_ref), # ref image + gr.update(visible=not use_ref), # enhance + ) + + restyle_mode.change( + fn=toggle_mode, + inputs=[restyle_mode], + outputs=[restyle_prompt, restyle_ref_image, restyle_enhance], + ) + + restyle_btn.click( + fn=lambda *args: asyncio.run(process_video_restyle(*args)), + inputs=[ + api_key, + restyle_input, + restyle_mode, + restyle_prompt, + restyle_ref_image, + restyle_seed, + restyle_enhance, + ], + outputs=[restyle_output, restyle_status], + ) + + # ================================================================ + # Tokens Tab + # ================================================================ + with gr.TabItem("Tokens"): + gr.Markdown( + """ + ### Create Client Token + + Create a short-lived token for client-side use. + These tokens are meant for temporary access and expire automatically. + """ + ) + + with gr.Row(): + with gr.Column(): + token_btn = gr.Button("Create Token", variant="primary") + with gr.Column(): + token_output = gr.Textbox( + label="Result", + lines=5, + interactive=False, + ) + + token_btn.click( + fn=lambda key: asyncio.run(create_token(key)), + inputs=[api_key], + outputs=[token_output], + ) + + gr.Markdown( + """ + --- + **Note**: This UI uses the Decart Python SDK. + For realtime/WebRTC features, use the example scripts in `examples/`. + """ + ) + + return demo + + +if __name__ == "__main__": + demo = create_ui() + demo.launch( + server_name="127.0.0.1", #localhost only + server_port=7860, + share=False, + ) From 357585dc1f5ad6701161bc5f5e5656e5ba84a3fd Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Sun, 4 Jan 2026 11:45:02 +0200 Subject: [PATCH 8/9] feat: update avatar-live ACK format and add initial prompt support - Update image ACK message type from 'image_set' to 'set_image_ack' - Change ACK format to use success:bool + error:string - Add initial_prompt option for sending prompt before WebRTC handshake - Add ErrorMessage, ReadyMessage, and IceRestartMessage types - Add ICE restart with TURN server support - Update tests for new message format and initial_prompt feature --- decart/realtime/client.py | 17 +++- decart/realtime/messages.py | 40 +++++++++- decart/realtime/types.py | 13 ++++ decart/realtime/webrtc_connection.py | 112 ++++++++++++++++++++++++--- decart/realtime/webrtc_manager.py | 2 + tests/test_realtime_unit.py | 106 ++++++++++++++++++++++++- 6 files changed, 273 insertions(+), 17 deletions(-) diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 5266a87..4411a56 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -83,8 +83,21 @@ async def connect( ) avatar_image_base64 = base64.b64encode(image_bytes).decode("utf-8") - await manager.connect(local_track, avatar_image_base64=avatar_image_base64) + # Prepare initial prompt if provided + initial_prompt: Optional[dict] = None + if options.initial_prompt: + initial_prompt = { + "text": options.initial_prompt.text, + "enhance": options.initial_prompt.enhance, + } + + await manager.connect( + local_track, + avatar_image_base64=avatar_image_base64, + initial_prompt=initial_prompt, + ) + # Handle initial_state.prompt for backward compatibility (after WebRTC connection) if options.initial_state: if options.initial_state.prompt: await client.set_prompt( @@ -167,7 +180,7 @@ async def set_image(self, image: FileInput) -> None: raise DecartSDKError("Image set acknowledgment timed out") if not result["success"]: - raise DecartSDKError(result.get("status") or "Failed to set avatar image") + raise DecartSDKError(result.get("error") or "Failed to set avatar image") finally: self._manager.unregister_image_set_wait() diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index 2b413df..287073f 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -51,11 +51,40 @@ class PromptAckMessage(BaseModel): error: Optional[str] = None -class ImageSetMessage(BaseModel): +class SetImageAckMessage(BaseModel): """Acknowledgment for avatar image set from server.""" - type: Literal["image_set"] - status: str + type: Literal["set_image_ack"] + success: bool + error: Optional[str] = None + + +class ErrorMessage(BaseModel): + """Error message from server.""" + + type: Literal["error"] + error: str + + +class ReadyMessage(BaseModel): + """Server ready signal.""" + + type: Literal["ready"] + + +class TurnConfig(BaseModel): + """TURN server configuration.""" + + username: str + credential: str + server_url: str + + +class IceRestartMessage(BaseModel): + """ICE restart message with TURN config.""" + + type: Literal["ice-restart"] + turn_config: TurnConfig # Discriminated union for incoming messages @@ -65,7 +94,10 @@ class ImageSetMessage(BaseModel): IceCandidateMessage, SessionIdMessage, PromptAckMessage, - ImageSetMessage, + SetImageAckMessage, + ErrorMessage, + ReadyMessage, + IceRestartMessage, ], Field(discriminator="type"), ] diff --git a/decart/realtime/types.py b/decart/realtime/types.py index cde9f59..51929b6 100644 --- a/decart/realtime/types.py +++ b/decart/realtime/types.py @@ -20,6 +20,17 @@ class AvatarOptions: """The avatar image to use. Can be bytes, Path, URL string, or file-like object.""" +@dataclass +class InitialPromptOptions: + """Options for initial prompt sent before WebRTC handshake.""" + + text: str + """The prompt text to send.""" + + enhance: bool = True + """Whether to enhance the prompt. Defaults to True.""" + + @dataclass class RealtimeConnectOptions: model: ModelDefinition @@ -27,3 +38,5 @@ class RealtimeConnectOptions: initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None avatar: Optional[AvatarOptions] = None + initial_prompt: Optional[InitialPromptOptions] = None + """Initial prompt to send before WebRTC handshake (optional).""" diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index 3d62f66..4302eb1 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -21,9 +21,13 @@ OfferMessage, IceCandidateMessage, IceCandidatePayload, + PromptMessage, PromptAckMessage, - ImageSetMessage, + SetImageAckMessage, SetAvatarImageMessage, + ErrorMessage, + ReadyMessage, + IceRestartMessage, OutgoingMessage, ) from .types import ConnectionState @@ -60,6 +64,7 @@ async def connect( integration: Optional[str] = None, is_avatar_live: bool = False, avatar_image_base64: Optional[str] = None, + initial_prompt: Optional[dict] = None, ) -> None: try: await self._set_state("connecting") @@ -80,6 +85,10 @@ async def connect( if is_avatar_live and avatar_image_base64: await self._send_avatar_image_and_wait(avatar_image_base64) + # Send initial prompt before WebRTC handshake (if provided) + if initial_prompt: + await self._send_initial_prompt_and_wait(initial_prompt) + await self._setup_peer_connection(local_track, is_avatar_live=is_avatar_live) await self._create_and_send_offer() @@ -115,11 +124,39 @@ async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = if not result["success"]: raise WebRTCError( - f"Failed to set avatar image: {result.get('status', 'unknown error')}" + f"Failed to set avatar image: {result.get('error', 'unknown error')}" ) finally: self.unregister_image_set_wait() + async def _send_initial_prompt_and_wait( + self, prompt: dict, timeout: float = 15.0 + ) -> None: + """Send initial prompt and wait for acknowledgment before WebRTC handshake.""" + prompt_text = prompt.get("text", "") + enhance = prompt.get("enhance", True) + + event, result = self.register_prompt_wait(prompt_text) + + try: + await self._send_message( + PromptMessage( + type="prompt", prompt=prompt_text, enhance_prompt=enhance + ) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + raise WebRTCError("Initial prompt acknowledgment timed out") + + if not result["success"]: + raise WebRTCError( + f"Failed to send initial prompt: {result.get('error', 'unknown error')}" + ) + finally: + self.unregister_prompt_wait(prompt_text) + async def _setup_peer_connection( self, local_track: Optional[MediaStreamTrack], @@ -220,8 +257,14 @@ async def _handle_message(self, data: dict) -> None: logger.debug(f"Session ID: {message.session_id}") elif message.type == "prompt_ack": self._handle_prompt_ack(message) - elif message.type == "image_set": - self._handle_image_set(message) + elif message.type == "set_image_ack": + self._handle_set_image_ack(message) + elif message.type == "error": + self._handle_error(message) + elif message.type == "ready": + logger.debug("Received ready signal from server") + elif message.type == "ice-restart": + await self._handle_ice_restart(message) async def _handle_answer(self, sdp: str) -> None: logger.debug("Received answer from server") @@ -261,17 +304,68 @@ def _handle_prompt_ack(self, message: PromptAckMessage) -> None: result["error"] = message.error event.set() - def _handle_image_set(self, message: ImageSetMessage) -> None: - logger.debug(f"Received image_set: status={message.status}") + def _handle_set_image_ack(self, message: SetImageAckMessage) -> None: + logger.debug(f"Received set_image_ack: success={message.success}, error={message.error}") if self._pending_image_set: event, result = self._pending_image_set - result["status"] = message.status - result["success"] = message.status == "success" + result["success"] = message.success + result["error"] = message.error event.set() + def _handle_error(self, message: ErrorMessage) -> None: + logger.error(f"Received error from server: {message.error}") + error = WebRTCError(message.error) + if self._on_error: + self._on_error(error) + + async def _handle_ice_restart(self, message: IceRestartMessage) -> None: + logger.info("Received ICE restart request from server") + turn_config = message.turn_config + # Re-setup peer connection with TURN server + await self._setup_peer_connection_with_turn(turn_config) + + async def _setup_peer_connection_with_turn(self, turn_config) -> None: + """Re-setup peer connection with TURN server for ICE restart.""" + from aiortc import RTCConfiguration, RTCIceServer + + ice_servers = [ + RTCIceServer(urls=["stun:stun.l.google.com:19302"]), + RTCIceServer( + urls=[turn_config.server_url], + username=turn_config.username, + credential=turn_config.credential, + ), + ] + config = RTCConfiguration(iceServers=ice_servers) + + # Close existing peer connection + if self._pc: + await self._pc.close() + + self._pc = RTCPeerConnection(configuration=config) + logger.debug("Re-created peer connection with TURN server for ICE restart") + + # Re-register callbacks + @self._pc.on("track") + def on_track(track: MediaStreamTrack): + logger.debug(f"Received remote track: {track.kind}") + if self._on_remote_stream: + self._on_remote_stream(track) + + @self._pc.on("connectionstatechange") + async def on_connection_state_change(): + logger.debug(f"Peer connection state: {self._pc.connectionState}") + if self._pc.connectionState == "connected": + await self._set_state("connected") + elif self._pc.connectionState in ["failed", "closed"]: + await self._set_state("disconnected") + + # Re-create and send offer + await self._create_and_send_offer() + def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: event = asyncio.Event() - result: dict = {"success": False, "status": None} + result: dict = {"success": False, "error": None} self._pending_image_set = (event, result) return event, result diff --git a/decart/realtime/webrtc_manager.py b/decart/realtime/webrtc_manager.py index b9793ed..f9f0764 100644 --- a/decart/realtime/webrtc_manager.py +++ b/decart/realtime/webrtc_manager.py @@ -57,6 +57,7 @@ async def connect( self, local_track: Optional[MediaStreamTrack], avatar_image_base64: Optional[str] = None, + initial_prompt: Optional[dict] = None, ) -> bool: try: await self._connection.connect( @@ -65,6 +66,7 @@ async def connect( integration=self._config.integration, is_avatar_live=self._config.is_avatar_live, avatar_image_base64=avatar_image_base64, + initial_prompt=initial_prompt, ) return True except Exception as e: diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index 50f9cb3..15882ff 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -358,7 +358,7 @@ async def test_avatar_live_set_image(): mock_manager.send_message = AsyncMock() image_set_event = asyncio.Event() - image_set_result = {"success": True, "status": "success"} + image_set_result = {"success": True, "error": None} mock_manager.register_image_set_wait = MagicMock( return_value=(image_set_event, image_set_result) @@ -456,7 +456,7 @@ async def test_avatar_live_set_image_timeout(): mock_manager.send_message = AsyncMock() image_set_event = asyncio.Event() - image_set_result = {"success": False, "status": None} + image_set_result = {"success": False, "error": None} mock_manager.register_image_set_wait = MagicMock( return_value=(image_set_event, image_set_result) @@ -492,3 +492,105 @@ async def test_avatar_live_set_image_timeout(): assert "timed out" in str(exc_info.value).lower() mock_manager.unregister_image_set_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_avatar_live_set_image_server_error(): + """Test set_image raises on server error""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + image_set_event = asyncio.Event() + image_set_result = {"success": False, "error": "Invalid image format"} + + mock_manager.register_image_set_wait = MagicMock( + return_value=(image_set_event, image_set_result) + ) + mock_manager.unregister_image_set_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + from decart.errors import DecartSDKError + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + ), + ) + + async def set_event(): + await asyncio.sleep(0.01) + image_set_event.set() + + asyncio.create_task(set_event()) + + with pytest.raises(DecartSDKError) as exc_info: + await realtime_client.set_image(b"test image") + + assert "Invalid image format" in str(exc_info.value) + mock_manager.unregister_image_set_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_with_initial_prompt(): + """Test connection with initial_prompt option""" + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.is_connected = MagicMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions, InitialPromptOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + initial_prompt=InitialPromptOptions(text="Test prompt", enhance=False), + ), + ) + + assert realtime_client is not None + mock_manager.connect.assert_called_once() + call_kwargs = mock_manager.connect.call_args[1] + assert "initial_prompt" in call_kwargs + assert call_kwargs["initial_prompt"] == {"text": "Test prompt", "enhance": False} From 29576b7c31077fbf538e2b4355994b79c44bf7c7 Mon Sep 17 00:00:00 2001 From: Yehoav Rabinovich Date: Sun, 4 Jan 2026 12:08:25 +0200 Subject: [PATCH 9/9] ruff and black --- decart/realtime/webrtc_connection.py | 9 ++------- test_ui.py | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index 4302eb1..f99d14a 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -26,7 +26,6 @@ SetImageAckMessage, SetAvatarImageMessage, ErrorMessage, - ReadyMessage, IceRestartMessage, OutgoingMessage, ) @@ -129,9 +128,7 @@ async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = finally: self.unregister_image_set_wait() - async def _send_initial_prompt_and_wait( - self, prompt: dict, timeout: float = 15.0 - ) -> None: + async def _send_initial_prompt_and_wait(self, prompt: dict, timeout: float = 15.0) -> None: """Send initial prompt and wait for acknowledgment before WebRTC handshake.""" prompt_text = prompt.get("text", "") enhance = prompt.get("enhance", True) @@ -140,9 +137,7 @@ async def _send_initial_prompt_and_wait( try: await self._send_message( - PromptMessage( - type="prompt", prompt=prompt_text, enhance_prompt=enhance - ) + PromptMessage(type="prompt", prompt=prompt_text, enhance_prompt=enhance) ) try: diff --git a/test_ui.py b/test_ui.py index 89cf319..e51caeb 100644 --- a/test_ui.py +++ b/test_ui.py @@ -606,7 +606,7 @@ def toggle_mode(use_ref): if __name__ == "__main__": demo = create_ui() demo.launch( - server_name="127.0.0.1", #localhost only + server_name="127.0.0.1", # localhost only server_port=7860, share=False, )