diff --git a/README.md b/README.md index b292ad6..dbfe4e8 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,28 @@ uv run python examples/realtime_synthetic.py uv lock --upgrade ``` +### Test UI + +The SDK includes an interactive test UI built with Gradio for quickly testing all SDK features without writing code. + +```bash +# Install Gradio +pip install gradio + +# Run the test UI +python test_ui.py +``` + +Then open http://localhost:7860 in your browser. + +The UI provides tabs for: +- **Image Generation** - Text-to-image and image-to-image transformations +- **Video Generation** - Text-to-video, image-to-video, and video-to-video +- **Video Restyle** - Restyle videos using text prompts or reference images +- **Tokens** - Create short-lived client tokens + +Enter your API key at the top of the interface to start testing. + ### Publishing a New Version The package is automatically published to PyPI when you create a GitHub release. diff --git a/decart/__init__.py b/decart/__init__.py index 988bc73..452a3dc 100644 --- a/decart/__init__.py +++ b/decart/__init__.py @@ -12,7 +12,7 @@ QueueResultError, TokenCreateError, ) -from .models import models, ModelDefinition +from .models import models, ModelDefinition, VideoRestyleInput from .types import FileInput, ModelState, Prompt from .queue import ( QueueClient, @@ -31,6 +31,7 @@ RealtimeClient, RealtimeConnectOptions, ConnectionState, + AvatarOptions, ) REALTIME_AVAILABLE = True @@ -39,6 +40,7 @@ RealtimeClient = None # type: ignore RealtimeConnectOptions = None # type: ignore ConnectionState = None # type: ignore + AvatarOptions = None # type: ignore __version__ = "0.0.1" @@ -56,6 +58,7 @@ "QueueResultError", "models", "ModelDefinition", + "VideoRestyleInput", "FileInput", "ModelState", "Prompt", @@ -75,5 +78,6 @@ "RealtimeClient", "RealtimeConnectOptions", "ConnectionState", + "AvatarOptions", ] ) diff --git a/decart/models.py b/decart/models.py index fa87f0f..9d7aa58 100644 --- a/decart/models.py +++ b/decart/models.py @@ -1,10 +1,10 @@ from typing import Literal, Optional, List, Generic, TypeVar -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, Field, ConfigDict, model_validator from .errors import ModelNotFoundError from .types import FileInput, MotionTrajectoryInput -RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt"] +RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt", "avatar-live"] VideoModels = Literal[ "lucy-dev-i2v", "lucy-fast-v2v", @@ -13,6 +13,7 @@ "lucy-pro-v2v", "lucy-pro-flf2v", "lucy-motion", + "lucy-restyle-v2v", ] ImageModels = Literal["lucy-pro-t2i", "lucy-pro-i2i"] Model = Literal[RealTimeModels, VideoModels, ImageModels] @@ -95,6 +96,36 @@ class ImageToMotionVideoInput(DecartBaseModel): resolution: Optional[str] = None +class VideoRestyleInput(DecartBaseModel): + """Input for lucy-restyle-v2v model. + + Must provide either `prompt` OR `reference_image`, but not both. + `enhance_prompt` is only valid when using `prompt`, not `reference_image`. 
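+
+    Example (illustrative; the media paths are placeholders):
+
+        VideoRestyleInput(prompt="anime style", data=Path("input.mp4"))
+        VideoRestyleInput(reference_image=Path("style.png"), data=Path("input.mp4"))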
+ """ + + prompt: Optional[str] = Field(default=None, min_length=1, max_length=1000) + reference_image: Optional[FileInput] = None + data: FileInput + seed: Optional[int] = None + resolution: Optional[str] = None + enhance_prompt: Optional[bool] = None + + @model_validator(mode="after") + def validate_prompt_or_reference_image(self) -> "VideoRestyleInput": + has_prompt = self.prompt is not None + has_reference_image = self.reference_image is not None + + if has_prompt == has_reference_image: + raise ValueError("Must provide either 'prompt' or 'reference_image', but not both") + + if has_reference_image and self.enhance_prompt is not None: + raise ValueError( + "'enhance_prompt' is only valid when using 'prompt', not 'reference_image'" + ) + + return self + + class TextToImageInput(BaseModel): prompt: str = Field( ..., @@ -144,6 +175,14 @@ class ImageToImageInput(DecartBaseModel): height=704, input_schema=BaseModel, ), + "avatar-live": ModelDefinition( + name="avatar-live", + url_path="/v1/avatar-live/stream", + fps=25, + width=1280, + height=720, + input_schema=BaseModel, + ), }, "video": { "lucy-dev-i2v": ModelDefinition( @@ -202,6 +241,14 @@ class ImageToImageInput(DecartBaseModel): height=704, input_schema=ImageToMotionVideoInput, ), + "lucy-restyle-v2v": ModelDefinition( + name="lucy-restyle-v2v", + url_path="/v1/generate/lucy-restyle-v2v", + fps=25, + width=1280, + height=704, + input_schema=VideoRestyleInput, + ), }, "image": { "lucy-pro-t2i": ModelDefinition( @@ -247,6 +294,7 @@ def video(model: VideoModels) -> VideoModelDefinition: - "lucy-dev-i2v" - Image-to-video (Dev quality) - "lucy-fast-v2v" - Video-to-video (Fast quality) - "lucy-motion" - Image-to-motion-video + - "lucy-restyle-v2v" - Video-to-video with prompt or reference image """ try: return _MODELS["video"][model] # type: ignore[return-value] diff --git a/decart/queue/request.py b/decart/queue/request.py index cca60b6..c242b62 100644 --- a/decart/queue/request.py +++ b/decart/queue/request.py @@ -24,7 +24,7 @@ async def submit_job( for key, value in inputs.items(): if value is not None: - if key in ("data", "start", "end"): + if key in ("data", "start", "end", "reference_image"): content, content_type = await file_input_to_bytes(value, session) form_data.add_field(key, content, content_type=content_type) else: diff --git a/decart/realtime/__init__.py b/decart/realtime/__init__.py index fd681ec..cb20fd2 100644 --- a/decart/realtime/__init__.py +++ b/decart/realtime/__init__.py @@ -1,8 +1,9 @@ from .client import RealtimeClient -from .types import RealtimeConnectOptions, ConnectionState +from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions __all__ = [ "RealtimeClient", "RealtimeConnectOptions", "ConnectionState", + "AvatarOptions", ] diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 37d79b2..4411a56 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -1,21 +1,33 @@ from typing import Callable, Optional import asyncio +import base64 import logging import uuid +import aiohttp from aiortc import MediaStreamTrack from .webrtc_manager import WebRTCManager, WebRTCConfiguration -from .messages import PromptMessage +from .messages import PromptMessage, SetAvatarImageMessage from .types import ConnectionState, RealtimeConnectOptions +from ..types import FileInput from ..errors import DecartSDKError, InvalidInputError, WebRTCError +from ..process.request import file_input_to_bytes logger = logging.getLogger(__name__) class RealtimeClient: - def __init__(self, 
manager: WebRTCManager, session_id: str): + def __init__( + self, + manager: WebRTCManager, + session_id: str, + http_session: Optional[aiohttp.ClientSession] = None, + is_avatar_live: bool = False, + ): self._manager = manager self.session_id = session_id + self._http_session = http_session + self._is_avatar_live = is_avatar_live self._connection_callbacks: list[Callable[[ConnectionState], None]] = [] self._error_callbacks: list[Callable[[DecartSDKError], None]] = [] @@ -24,7 +36,7 @@ async def connect( cls, base_url: str, api_key: str, - local_track: MediaStreamTrack, + local_track: Optional[MediaStreamTrack], options: RealtimeConnectOptions, integration: Optional[str] = None, ) -> "RealtimeClient": @@ -32,6 +44,8 @@ async def connect( ws_url = f"{base_url}{options.model.url_path}" ws_url += f"?api_key={api_key}&model={options.model.name}" + is_avatar_live = options.model.name == "avatar-live" + config = WebRTCConfiguration( webrtc_url=ws_url, api_key=api_key, @@ -43,17 +57,47 @@ async def connect( initial_state=options.initial_state, customize_offer=options.customize_offer, integration=integration, + is_avatar_live=is_avatar_live, ) + # Create HTTP session for file conversions + http_session = aiohttp.ClientSession() + manager = WebRTCManager(config) - client = cls(manager=manager, session_id=session_id) + client = cls( + manager=manager, + session_id=session_id, + http_session=http_session, + is_avatar_live=is_avatar_live, + ) config.on_connection_state_change = client._emit_connection_change config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error)) try: - await manager.connect(local_track) + # For avatar-live, convert and send avatar image before WebRTC connection + avatar_image_base64: Optional[str] = None + if is_avatar_live and options.avatar: + image_bytes, _ = await file_input_to_bytes( + options.avatar.avatar_image, http_session + ) + avatar_image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + # Prepare initial prompt if provided + initial_prompt: Optional[dict] = None + if options.initial_prompt: + initial_prompt = { + "text": options.initial_prompt.text, + "enhance": options.initial_prompt.enhance, + } + + await manager.connect( + local_track, + avatar_image_base64=avatar_image_base64, + initial_prompt=initial_prompt, + ) + # Handle initial_state.prompt for backward compatibility (after WebRTC connection) if options.initial_state: if options.initial_state.prompt: await client.set_prompt( @@ -61,6 +105,7 @@ async def connect( enrich=options.initial_state.prompt.enrich, ) except Exception as e: + await http_session.close() raise WebRTCError(str(e), cause=e) return client @@ -100,6 +145,45 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None: finally: self._manager.unregister_prompt_wait(prompt) + async def set_image(self, image: FileInput) -> None: + """Set or update the avatar image. + + Only available for avatar-live model. + + Args: + image: The image to set. Can be bytes, Path, URL string, or file-like object. + + Raises: + InvalidInputError: If not using avatar-live model or image is invalid. + DecartSDKError: If the server fails to acknowledge the image. 
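+
+        Example (illustrative; assumes a connected avatar-live client):
+
+            await client.set_image(Path("new_avatar.png"))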
+ """ + if not self._is_avatar_live: + raise InvalidInputError("set_image() is only available for avatar-live model") + + if not self._http_session: + raise InvalidInputError("HTTP session not available") + + # Convert image to base64 + image_bytes, _ = await file_input_to_bytes(image, self._http_session) + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + + event, result = self._manager.register_image_set_wait() + + try: + await self._manager.send_message( + SetAvatarImageMessage(type="set_image", image_data=image_base64) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=15.0) + except asyncio.TimeoutError: + raise DecartSDKError("Image set acknowledgment timed out") + + if not result["success"]: + raise DecartSDKError(result.get("error") or "Failed to set avatar image") + finally: + self._manager.unregister_image_set_wait() + def is_connected(self) -> bool: return self._manager.is_connected() @@ -108,6 +192,8 @@ def get_connection_state(self) -> ConnectionState: async def disconnect(self) -> None: await self._manager.cleanup() + if self._http_session and not self._http_session.closed: + await self._http_session.close() def on(self, event: str, callback: Callable) -> None: if event == "connection_change": diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index c59c5c6..287073f 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -51,9 +51,54 @@ class PromptAckMessage(BaseModel): error: Optional[str] = None +class SetImageAckMessage(BaseModel): + """Acknowledgment for avatar image set from server.""" + + type: Literal["set_image_ack"] + success: bool + error: Optional[str] = None + + +class ErrorMessage(BaseModel): + """Error message from server.""" + + type: Literal["error"] + error: str + + +class ReadyMessage(BaseModel): + """Server ready signal.""" + + type: Literal["ready"] + + +class TurnConfig(BaseModel): + """TURN server configuration.""" + + username: str + credential: str + server_url: str + + +class IceRestartMessage(BaseModel): + """ICE restart message with TURN config.""" + + type: Literal["ice-restart"] + turn_config: TurnConfig + + # Discriminated union for incoming messages IncomingMessage = Annotated[ - Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage], + Union[ + AnswerMessage, + IceCandidateMessage, + SessionIdMessage, + PromptAckMessage, + SetImageAckMessage, + ErrorMessage, + ReadyMessage, + IceRestartMessage, + ], Field(discriminator="type"), ] @@ -79,8 +124,15 @@ class PromptMessage(BaseModel): enhance_prompt: bool = True +class SetAvatarImageMessage(BaseModel): + """Set avatar image message.""" + + type: Literal["set_image"] + image_data: str # Base64-encoded image + + # Outgoing message union (no discriminator needed - we know what we're sending) -OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage] +OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage, SetAvatarImageMessage] def parse_incoming_message(data: dict) -> IncomingMessage: diff --git a/decart/realtime/types.py b/decart/realtime/types.py index 7234e24..51929b6 100644 --- a/decart/realtime/types.py +++ b/decart/realtime/types.py @@ -1,7 +1,7 @@ from typing import Literal, Callable, Optional from dataclasses import dataclass from ..models import ModelDefinition -from ..types import ModelState +from ..types import ModelState, FileInput try: from aiortc import MediaStreamTrack @@ -12,9 +12,31 @@ ConnectionState = Literal["connecting", "connected", "disconnected"] 
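+
+# The dataclasses below combine roughly like this (illustrative sketch;
+# `on_remote_stream` and "avatar.png" are caller-supplied placeholders;
+# see examples/avatar_live.py for a complete program):
+#
+#     options = RealtimeConnectOptions(
+#         model=models.realtime("avatar-live"),
+#         on_remote_stream=on_remote_stream,
+#         avatar=AvatarOptions(avatar_image=Path("avatar.png")),
+#         initial_prompt=InitialPromptOptions(text="a friendly wave", enhance=True),
+#     )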
+@dataclass +class AvatarOptions: + """Options for avatar-live model.""" + + avatar_image: FileInput + """The avatar image to use. Can be bytes, Path, URL string, or file-like object.""" + + +@dataclass +class InitialPromptOptions: + """Options for initial prompt sent before WebRTC handshake.""" + + text: str + """The prompt text to send.""" + + enhance: bool = True + """Whether to enhance the prompt. Defaults to True.""" + + @dataclass class RealtimeConnectOptions: model: ModelDefinition on_remote_stream: Callable[[MediaStreamTrack], None] initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None + avatar: Optional[AvatarOptions] = None + initial_prompt: Optional[InitialPromptOptions] = None + """Initial prompt to send before WebRTC handshake (optional).""" diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index ad087e4..f99d14a 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -21,7 +21,12 @@ OfferMessage, IceCandidateMessage, IceCandidatePayload, + PromptMessage, PromptAckMessage, + SetImageAckMessage, + SetAvatarImageMessage, + ErrorMessage, + IceRestartMessage, OutgoingMessage, ) from .types import ConnectionState @@ -48,13 +53,17 @@ def __init__( self._ws_task: Optional[asyncio.Task] = None self._ice_candidates_queue: list[RTCIceCandidate] = [] self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {} + self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None async def connect( self, url: str, - local_track: MediaStreamTrack, + local_track: Optional[MediaStreamTrack], timeout: float = 30, integration: Optional[str] = None, + is_avatar_live: bool = False, + avatar_image_base64: Optional[str] = None, + initial_prompt: Optional[dict] = None, ) -> None: try: await self._set_state("connecting") @@ -71,7 +80,15 @@ async def connect( self._ws_task = asyncio.create_task(self._receive_messages()) - await self._setup_peer_connection(local_track) + # For avatar-live, send avatar image before WebRTC handshake + if is_avatar_live and avatar_image_base64: + await self._send_avatar_image_and_wait(avatar_image_base64) + + # Send initial prompt before WebRTC handshake (if provided) + if initial_prompt: + await self._send_initial_prompt_and_wait(initial_prompt) + + await self._setup_peer_connection(local_track, is_avatar_live=is_avatar_live) await self._create_and_send_offer() @@ -90,7 +107,56 @@ async def connect( self._on_error(e) raise WebRTCError(str(e), cause=e) - async def _setup_peer_connection(self, local_track: MediaStreamTrack) -> None: + async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = 15.0) -> None: + """Send avatar image and wait for acknowledgment.""" + event, result = self.register_image_set_wait() + + try: + await self._send_message( + SetAvatarImageMessage(type="set_image", image_data=image_base64) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + raise WebRTCError("Avatar image acknowledgment timed out") + + if not result["success"]: + raise WebRTCError( + f"Failed to set avatar image: {result.get('error', 'unknown error')}" + ) + finally: + self.unregister_image_set_wait() + + async def _send_initial_prompt_and_wait(self, prompt: dict, timeout: float = 15.0) -> None: + """Send initial prompt and wait for acknowledgment before WebRTC handshake.""" + prompt_text = prompt.get("text", "") + enhance = prompt.get("enhance", True) + + event, result = 
self.register_prompt_wait(prompt_text) + + try: + await self._send_message( + PromptMessage(type="prompt", prompt=prompt_text, enhance_prompt=enhance) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + raise WebRTCError("Initial prompt acknowledgment timed out") + + if not result["success"]: + raise WebRTCError( + f"Failed to send initial prompt: {result.get('error', 'unknown error')}" + ) + finally: + self.unregister_prompt_wait(prompt_text) + + async def _setup_peer_connection( + self, + local_track: Optional[MediaStreamTrack], + is_avatar_live: bool = False, + ) -> None: config = RTCConfiguration(iceServers=[RTCIceServer(urls=["stun:stun.l.google.com:19302"])]) self._pc = RTCPeerConnection(configuration=config) @@ -128,8 +194,15 @@ async def on_connection_state_change(): async def on_ice_connection_state_change(): logger.debug(f"ICE connection state: {self._pc.iceConnectionState}") - self._pc.addTrack(local_track) - logger.debug("Added local track to peer connection") + # For avatar-live, add recv-only video transceiver + if is_avatar_live: + self._pc.addTransceiver("video", direction="recvonly") + logger.debug("Added video transceiver (recvonly) for avatar-live") + + # Add local audio track if provided + if local_track: + self._pc.addTrack(local_track) + logger.debug("Added local track to peer connection") async def _create_and_send_offer(self) -> None: logger.debug("Creating offer...") @@ -179,6 +252,14 @@ async def _handle_message(self, data: dict) -> None: logger.debug(f"Session ID: {message.session_id}") elif message.type == "prompt_ack": self._handle_prompt_ack(message) + elif message.type == "set_image_ack": + self._handle_set_image_ack(message) + elif message.type == "error": + self._handle_error(message) + elif message.type == "ready": + logger.debug("Received ready signal from server") + elif message.type == "ice-restart": + await self._handle_ice_restart(message) async def _handle_answer(self, sdp: str) -> None: logger.debug("Received answer from server") @@ -218,6 +299,74 @@ def _handle_prompt_ack(self, message: PromptAckMessage) -> None: result["error"] = message.error event.set() + def _handle_set_image_ack(self, message: SetImageAckMessage) -> None: + logger.debug(f"Received set_image_ack: success={message.success}, error={message.error}") + if self._pending_image_set: + event, result = self._pending_image_set + result["success"] = message.success + result["error"] = message.error + event.set() + + def _handle_error(self, message: ErrorMessage) -> None: + logger.error(f"Received error from server: {message.error}") + error = WebRTCError(message.error) + if self._on_error: + self._on_error(error) + + async def _handle_ice_restart(self, message: IceRestartMessage) -> None: + logger.info("Received ICE restart request from server") + turn_config = message.turn_config + # Re-setup peer connection with TURN server + await self._setup_peer_connection_with_turn(turn_config) + + async def _setup_peer_connection_with_turn(self, turn_config) -> None: + """Re-setup peer connection with TURN server for ICE restart.""" + from aiortc import RTCConfiguration, RTCIceServer + + ice_servers = [ + RTCIceServer(urls=["stun:stun.l.google.com:19302"]), + RTCIceServer( + urls=[turn_config.server_url], + username=turn_config.username, + credential=turn_config.credential, + ), + ] + config = RTCConfiguration(iceServers=ice_servers) + + # Close existing peer connection + if self._pc: + await self._pc.close() + + self._pc = 
RTCPeerConnection(configuration=config) + logger.debug("Re-created peer connection with TURN server for ICE restart") + + # Re-register callbacks + @self._pc.on("track") + def on_track(track: MediaStreamTrack): + logger.debug(f"Received remote track: {track.kind}") + if self._on_remote_stream: + self._on_remote_stream(track) + + @self._pc.on("connectionstatechange") + async def on_connection_state_change(): + logger.debug(f"Peer connection state: {self._pc.connectionState}") + if self._pc.connectionState == "connected": + await self._set_state("connected") + elif self._pc.connectionState in ["failed", "closed"]: + await self._set_state("disconnected") + + # Re-create and send offer + await self._create_and_send_offer() + + def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: + event = asyncio.Event() + result: dict = {"success": False, "error": None} + self._pending_image_set = (event, result) + return event, result + + def unregister_image_set_wait(self) -> None: + self._pending_image_set = None + def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: event = asyncio.Event() result: dict = {"success": False, "error": None} diff --git a/decart/realtime/webrtc_manager.py b/decart/realtime/webrtc_manager.py index d39a067..f9f0764 100644 --- a/decart/realtime/webrtc_manager.py +++ b/decart/realtime/webrtc_manager.py @@ -31,6 +31,7 @@ class WebRTCConfiguration: initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None integration: Optional[str] = None + is_avatar_live: bool = False def _is_retryable_error(exception: Exception) -> bool: @@ -52,12 +53,20 @@ def __init__(self, configuration: WebRTCConfiguration): before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, ) - async def connect(self, local_track: MediaStreamTrack) -> bool: + async def connect( + self, + local_track: Optional[MediaStreamTrack], + avatar_image_base64: Optional[str] = None, + initial_prompt: Optional[dict] = None, + ) -> bool: try: await self._connection.connect( url=self._config.webrtc_url, local_track=local_track, integration=self._config.integration, + is_avatar_live=self._config.is_avatar_live, + avatar_image_base64=avatar_image_base64, + initial_prompt=initial_prompt, ) return True except Exception as e: @@ -91,3 +100,9 @@ def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: def unregister_prompt_wait(self, prompt: str) -> None: self._connection.unregister_prompt_wait(prompt) + + def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: + return self._connection.register_image_set_wait() + + def unregister_image_set_wait(self) -> None: + self._connection.unregister_image_set_wait() diff --git a/examples/avatar_live.py b/examples/avatar_live.py new file mode 100644 index 0000000..bebb911 --- /dev/null +++ b/examples/avatar_live.py @@ -0,0 +1,180 @@ +""" +Avatar Live Example + +This example demonstrates how to use the avatar-live model to animate an avatar image. +The avatar can be animated with audio input (microphone or audio file). 
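+Received video frames are recorded to output_avatar_live.mp4 via aiortc's MediaRecorder.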
+
+Usage:
+    # With audio file:
+    DECART_API_KEY=your-key python avatar_live.py avatar.png audio.mp3
+
+    # With just avatar image (will wait for audio):
+    DECART_API_KEY=your-key python avatar_live.py avatar.png
+
+Requirements:
+    pip install decart[realtime]
+"""
+
+import asyncio
+import os
+import sys
+from pathlib import Path
+
+try:
+    from aiortc.contrib.media import MediaPlayer, MediaRecorder
+except ImportError:
+    print("aiortc is required for this example.")
+    print("Install with: pip install decart[realtime]")
+    sys.exit(1)
+
+from decart import DecartClient, models
+
+
+async def main():
+    api_key = os.getenv("DECART_API_KEY")
+    if not api_key:
+        print("Error: DECART_API_KEY environment variable not set")
+        print("Usage: DECART_API_KEY=your-key python avatar_live.py <avatar_image> [audio_file]")
+        return
+
+    if len(sys.argv) < 2:
+        print("Usage: python avatar_live.py <avatar_image> [audio_file]")
+        print("")
+        print("Arguments:")
+        print("  avatar_image - Path to avatar image (PNG, JPG)")
+        print("  audio_file   - Optional: Path to audio file for the avatar to speak")
+        print("")
+        print("Examples:")
+        print("  python avatar_live.py avatar.png")
+        print("  python avatar_live.py avatar.png speech.mp3")
+        return
+
+    avatar_image = sys.argv[1]
+    if not os.path.exists(avatar_image):
+        print(f"Error: Avatar image not found: {avatar_image}")
+        return
+
+    audio_file = sys.argv[2] if len(sys.argv) > 2 else None
+    if audio_file and not os.path.exists(audio_file):
+        print(f"Error: Audio file not found: {audio_file}")
+        return
+
+    print(f"šŸ–¼ļø Avatar image: {avatar_image}")
+    if audio_file:
+        print(f"šŸ”Š Audio file: {audio_file}")
+
+    # Load audio if provided
+    audio_track = None
+    if audio_file:
+        print("Loading audio file...")
+        player = MediaPlayer(audio_file)
+        if player.audio:
+            audio_track = player.audio
+            print("āœ“ Audio loaded")
+        else:
+            print("āš ļø Warning: No audio stream found in file, continuing without audio")
+
+    try:
+        from decart.realtime.client import RealtimeClient
+        from decart.realtime.types import RealtimeConnectOptions, AvatarOptions
+    except ImportError:
+        print("Error: Realtime API not available")
+        print("Install with: pip install decart[realtime]")
+        return
+
+    print("\nCreating Decart client...")
+    async with DecartClient(api_key=api_key) as client:
+        model = models.realtime("avatar-live")
+        print(f"Using model: {model.name}")
+
+        frame_count = 0
+        recorder = None
+        output_file = Path("output_avatar_live.mp4")
+
+        def on_remote_stream(track):
+            nonlocal frame_count, recorder
+            frame_count += 1
+            if frame_count % 25 == 0:
+                print(f"šŸ“¹ Received {frame_count} frames...")
+
+            if recorder is None:
+                print(f"šŸ’¾ Recording to {output_file}")
+                recorder = MediaRecorder(str(output_file))
+                recorder.addTrack(track)
+                asyncio.create_task(recorder.start())
+
+        def on_connection_change(state):
+            print(f"šŸ”„ Connection state: {state}")
+
+        def on_error(error):
+            print(f"āŒ Error: {error.__class__.__name__} - {error.message}")
+
+        print("\nConnecting to Avatar Live API...")
+        print("(Sending avatar image...)")
+
+        try:
+            realtime_client = await RealtimeClient.connect(
+                base_url=client.base_url,
+                api_key=client.api_key,
+                local_track=audio_track,  # Can be None if no audio
+                options=RealtimeConnectOptions(
+                    model=model,
+                    on_remote_stream=on_remote_stream,
+                    avatar=AvatarOptions(avatar_image=Path(avatar_image)),
+                ),
+            )
+
+            realtime_client.on("connection_change", on_connection_change)
+            realtime_client.on("error", on_error)
+
+            print("āœ“ Connected!")
+            print(f"Session ID: 
{realtime_client.session_id}") + + if audio_file: + print("\nPlaying audio through avatar...") + print("(The avatar will animate based on the audio)") + else: + print("\nNo audio provided - avatar will be static") + print("You can update the avatar image dynamically using set_image()") + + print("\nPress Ctrl+C to stop and save the recording...") + + # Demo: Update avatar image after 5 seconds (if you want to test set_image) + # Uncomment the following to test dynamic image updates: + # await asyncio.sleep(5) + # print("Updating avatar image...") + # await realtime_client.set_image(Path("new_avatar.png")) + # print("āœ“ Avatar image updated!") + + try: + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + print(f"\n\nāœ“ Received {frame_count} frames total") + + finally: + if recorder: + try: + print(f"šŸ’¾ Saving video to {output_file}...") + await asyncio.sleep(0.5) + await recorder.stop() + print(f"āœ“ Video saved to {output_file}") + except Exception as e: + print(f"āš ļø Warning: Could not save video cleanly: {e}") + print(" Video file may be incomplete") + + except Exception as e: + print(f"\nāŒ Connection failed: {e}") + import traceback + + traceback.print_exc() + + finally: + if "realtime_client" in locals(): + print("\nDisconnecting...") + await realtime_client.disconnect() + print("āœ“ Disconnected") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/video_restyle.py b/examples/video_restyle.py new file mode 100644 index 0000000..dd639d1 --- /dev/null +++ b/examples/video_restyle.py @@ -0,0 +1,139 @@ +""" +Video Restyle Example + +This example demonstrates how to use the lucy-restyle-v2v model to restyle a video +using either a text prompt OR a reference image. + +Usage: + # With text prompt: + DECART_API_KEY=your-key python video_restyle.py input.mp4 --prompt "anime style" + + # With reference image: + DECART_API_KEY=your-key python video_restyle.py input.mp4 --reference style.png + +Requirements: + pip install decart +""" + +import asyncio +import argparse +import os +import sys +from pathlib import Path + +from decart import DecartClient, models + + +async def main(): + parser = argparse.ArgumentParser( + description="Restyle a video using text prompt or reference image" + ) + parser.add_argument("video", help="Path to input video file") + parser.add_argument( + "--prompt", + "-p", + help="Text prompt describing the style (e.g., 'anime style', 'oil painting')", + ) + parser.add_argument("--reference", "-r", help="Path to reference image for style transfer") + parser.add_argument("--output", "-o", help="Output file path (default: output_restyle.mp4)") + parser.add_argument("--seed", "-s", type=int, help="Random seed for reproducibility") + parser.add_argument( + "--enhance", + action="store_true", + default=True, + help="Enhance the prompt (only with --prompt, default: True)", + ) + parser.add_argument("--no-enhance", action="store_true", help="Disable prompt enhancement") + + args = parser.parse_args() + + # Validate arguments + if not args.prompt and not args.reference: + print("Error: Must provide either --prompt or --reference") + parser.print_help() + sys.exit(1) + + if args.prompt and args.reference: + print("Error: Cannot use both --prompt and --reference") + print(" Please choose one or the other") + sys.exit(1) + + api_key = os.getenv("DECART_API_KEY") + if not api_key: + print("Error: DECART_API_KEY environment variable not set") + sys.exit(1) + + video_path = Path(args.video) + if not video_path.exists(): + print(f"Error: 
Video file not found: {video_path}") + sys.exit(1) + + if args.reference: + ref_path = Path(args.reference) + if not ref_path.exists(): + print(f"Error: Reference image not found: {ref_path}") + sys.exit(1) + + output_path = args.output or f"output_restyle_{video_path.stem}.mp4" + + print("=" * 50) + print("Video Restyle") + print("=" * 50) + print(f"Input video: {video_path}") + if args.prompt: + print(f"Style: Text prompt - '{args.prompt}'") + print(f"Enhance prompt: {not args.no_enhance}") + else: + print(f"Style: Reference image - {args.reference}") + print(f"Output: {output_path}") + if args.seed: + print(f"Seed: {args.seed}") + print("=" * 50) + + async with DecartClient(api_key=api_key) as client: + # Build options + options = { + "model": models.video("lucy-restyle-v2v"), + "data": video_path, + } + + if args.prompt: + options["prompt"] = args.prompt + options["enhance_prompt"] = not args.no_enhance + else: + options["reference_image"] = Path(args.reference) + + if args.seed: + options["seed"] = args.seed + + def on_status_change(job): + status_emoji = { + "pending": "ā³", + "processing": "šŸ”„", + "completed": "āœ…", + "failed": "āŒ", + } + emoji = status_emoji.get(job.status, "•") + print(f"{emoji} Status: {job.status}") + + options["on_status_change"] = on_status_change + + print("\nSubmitting job...") + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + print(f"\nāŒ Job failed: {result.error}") + sys.exit(1) + + print("\nāœ… Job completed!") + print(f"šŸ’¾ Saving to {output_path}...") + + with open(output_path, "wb") as f: + f.write(result.data) + + print(f"āœ“ Video saved to {output_path}") + print(f" Size: {len(result.data) / 1024 / 1024:.2f} MB") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test_ui.py b/test_ui.py new file mode 100644 index 0000000..e51caeb --- /dev/null +++ b/test_ui.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +""" +Decart SDK Test UI - Interactive testing interface for the Python SDK. + +Usage: + pip install gradio + python test_ui.py + +Then open http://localhost:7860 in your browser. +""" + +import asyncio +import gradio as gr +from pathlib import Path +from typing import Optional +import tempfile +import os + +# Import the SDK +from decart import DecartClient, models + + +def get_client(api_key: str) -> DecartClient: + """Create a Decart client with the given API key.""" + if not api_key or not api_key.strip(): + raise ValueError("Please enter an API key") + return DecartClient(api_key=api_key.strip()) + + +# ============================================================================ +# Image Processing (Process API) +# ============================================================================ + + +async def process_text_to_image( + api_key: str, + prompt: str, + seed: Optional[int], + resolution: str, + orientation: str, +) -> tuple[Optional[bytes], str]: + """Generate an image from text prompt.""" + try: + client = get_client(api_key) + + options = { + "model": models.image("lucy-pro-t2i"), + "prompt": prompt, + } + if seed: + options["seed"] = seed + if resolution and resolution != "default": + options["resolution"] = resolution + if orientation and orientation != "default": + options["orientation"] = orientation + + result = await client.process(options) + return result, f"Success! 
Generated image from prompt: '{prompt[:50]}...'" + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_image_to_image( + api_key: str, + prompt: str, + input_image: str, + seed: Optional[int], + strength: float, +) -> tuple[Optional[bytes], str]: + """Transform an image with a prompt.""" + try: + if not input_image: + return None, "Please upload an image" + + client = get_client(api_key) + + options = { + "model": models.image("lucy-pro-i2i"), + "prompt": prompt, + "data": Path(input_image), + } + if seed: + options["seed"] = seed + if strength: + options["strength"] = strength + + result = await client.process(options) + return result, f"Success! Transformed image with prompt: '{prompt[:50]}...'" + except Exception as e: + return None, f"Error: {str(e)}" + + +# ============================================================================ +# Video Processing (Queue API) +# ============================================================================ + + +async def process_video_t2v( + api_key: str, + prompt: str, + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Generate a video from text prompt.""" + try: + client = get_client(api_key) + + options = { + "model": models.video("lucy-pro-t2v"), + "prompt": prompt, + } + if seed: + options["seed"] = seed + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + # Save to temp file + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + return f.name, f"Success! Generated video from prompt: '{prompt[:50]}...'" + + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_video_v2v( + api_key: str, + prompt: str, + input_video: str, + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Transform a video with a prompt.""" + try: + if not input_video: + return None, "Please upload a video" + + client = get_client(api_key) + + options = { + "model": models.video("lucy-pro-v2v"), + "prompt": prompt, + "data": Path(input_video), + } + if seed: + options["seed"] = seed + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + return f.name, f"Success! 
Transformed video with prompt: '{prompt[:50]}...'" + + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_video_restyle( + api_key: str, + input_video: str, + use_reference_image: bool, + prompt: str, + reference_image: Optional[str], + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Restyle a video with prompt OR reference image.""" + try: + if not input_video: + return None, "Please upload a video" + + client = get_client(api_key) + + options = { + "model": models.video("lucy-restyle-v2v"), + "data": Path(input_video), + } + + if use_reference_image: + if not reference_image: + return None, "Please upload a reference image" + options["reference_image"] = Path(reference_image) + else: + if not prompt: + return None, "Please enter a prompt" + options["prompt"] = prompt + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + if seed: + options["seed"] = seed + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + mode = "reference image" if use_reference_image else f"prompt: '{prompt[:30]}...'" + return f.name, f"Success! Restyled video with {mode}" + + except Exception as e: + return None, f"Error: {str(e)}" + + +async def process_video_i2v( + api_key: str, + prompt: str, + input_image: str, + seed: Optional[int], + enhance_prompt: bool, + progress=gr.Progress(), +) -> tuple[Optional[str], str]: + """Generate a video from an image.""" + try: + if not input_image: + return None, "Please upload an image" + + client = get_client(api_key) + + options = { + "model": models.video("lucy-pro-i2v"), + "prompt": prompt, + "data": Path(input_image), + } + if seed: + options["seed"] = seed + if enhance_prompt is not None: + options["enhance_prompt"] = enhance_prompt + + progress(0.1, desc="Submitting job...") + + def on_status_change(job): + if job.status == "pending": + progress(0.2, desc="Job pending...") + elif job.status == "processing": + progress(0.5, desc="Processing video...") + + options["on_status_change"] = on_status_change + + result = await client.queue.submit_and_poll(options) + + if result.status == "failed": + return None, f"Job failed: {result.error}" + + progress(0.9, desc="Saving video...") + + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + f.write(result.data) + return f.name, f"Success! 
Generated video from image" + + except Exception as e: + return None, f"Error: {str(e)}" + + +# ============================================================================ +# Tokens API +# ============================================================================ + + +async def create_token(api_key: str) -> str: + """Create a short-lived client token.""" + try: + client = get_client(api_key) + result = await client.tokens.create() + return f"Success!\n\nToken: {result.api_key}\nExpires: {result.expires_at}" + except Exception as e: + return f"Error: {str(e)}" + + +# ============================================================================ +# Gradio UI +# ============================================================================ + + +def create_ui(): + """Create the Gradio interface.""" + + with gr.Blocks( + title="Decart SDK Test UI", + theme=gr.themes.Soft(), + css=""" + .status-success { color: green; } + .status-error { color: red; } + """, + ) as demo: + gr.Markdown( + """ + # Decart SDK Test UI + + Interactive testing interface for the Decart Python SDK. + Enter your API key below to get started. + """ + ) + + # API Key input (shared across all tabs) + api_key = gr.Textbox( + label="API Key", + placeholder="Enter your Decart API key", + type="password", + elem_id="api-key-input", + ) + + with gr.Tabs(): + # ================================================================ + # Image Processing Tab + # ================================================================ + with gr.TabItem("Image Generation"): + gr.Markdown("### Text to Image") + with gr.Row(): + with gr.Column(): + t2i_prompt = gr.Textbox( + label="Prompt", + placeholder="A beautiful sunset over mountains", + lines=3, + ) + with gr.Row(): + t2i_seed = gr.Number(label="Seed (optional)", precision=0) + t2i_resolution = gr.Dropdown( + label="Resolution", + choices=["default", "720p", "1080p"], + value="default", + ) + t2i_orientation = gr.Dropdown( + label="Orientation", + choices=["default", "landscape", "portrait", "square"], + value="default", + ) + t2i_btn = gr.Button("Generate Image", variant="primary") + with gr.Column(): + t2i_output = gr.Image(label="Generated Image", type="filepath") + t2i_status = gr.Textbox(label="Status", interactive=False) + + t2i_btn.click( + fn=lambda *args: asyncio.run(process_text_to_image(*args)), + inputs=[api_key, t2i_prompt, t2i_seed, t2i_resolution, t2i_orientation], + outputs=[t2i_output, t2i_status], + ) + + gr.Markdown("---") + gr.Markdown("### Image to Image") + with gr.Row(): + with gr.Column(): + i2i_input = gr.Image(label="Input Image", type="filepath") + i2i_prompt = gr.Textbox( + label="Prompt", + placeholder="Make it look like anime", + lines=2, + ) + with gr.Row(): + i2i_seed = gr.Number(label="Seed (optional)", precision=0) + i2i_strength = gr.Slider( + label="Strength", + minimum=0.0, + maximum=1.0, + value=0.75, + step=0.05, + ) + i2i_btn = gr.Button("Transform Image", variant="primary") + with gr.Column(): + i2i_output = gr.Image(label="Transformed Image", type="filepath") + i2i_status = gr.Textbox(label="Status", interactive=False) + + i2i_btn.click( + fn=lambda *args: asyncio.run(process_image_to_image(*args)), + inputs=[api_key, i2i_prompt, i2i_input, i2i_seed, i2i_strength], + outputs=[i2i_output, i2i_status], + ) + + # ================================================================ + # Video Processing Tab + # ================================================================ + with gr.TabItem("Video Generation"): + gr.Markdown("### Text to Video") + with 
gr.Row(): + with gr.Column(): + t2v_prompt = gr.Textbox( + label="Prompt", + placeholder="A cat walking in a park", + lines=3, + ) + with gr.Row(): + t2v_seed = gr.Number(label="Seed (optional)", precision=0) + t2v_enhance = gr.Checkbox(label="Enhance Prompt", value=True) + t2v_btn = gr.Button("Generate Video", variant="primary") + with gr.Column(): + t2v_output = gr.Video(label="Generated Video") + t2v_status = gr.Textbox(label="Status", interactive=False) + + t2v_btn.click( + fn=lambda *args: asyncio.run(process_video_t2v(*args)), + inputs=[api_key, t2v_prompt, t2v_seed, t2v_enhance], + outputs=[t2v_output, t2v_status], + ) + + gr.Markdown("---") + gr.Markdown("### Image to Video") + with gr.Row(): + with gr.Column(): + i2v_input = gr.Image(label="Input Image", type="filepath") + i2v_prompt = gr.Textbox( + label="Prompt", + placeholder="The scene comes to life", + lines=2, + ) + with gr.Row(): + i2v_seed = gr.Number(label="Seed (optional)", precision=0) + i2v_enhance = gr.Checkbox(label="Enhance Prompt", value=True) + i2v_btn = gr.Button("Generate Video", variant="primary") + with gr.Column(): + i2v_output = gr.Video(label="Generated Video") + i2v_status = gr.Textbox(label="Status", interactive=False) + + i2v_btn.click( + fn=lambda *args: asyncio.run(process_video_i2v(*args)), + inputs=[api_key, i2v_prompt, i2v_input, i2v_seed, i2v_enhance], + outputs=[i2v_output, i2v_status], + ) + + gr.Markdown("---") + gr.Markdown("### Video to Video") + with gr.Row(): + with gr.Column(): + v2v_input = gr.Video(label="Input Video") + v2v_prompt = gr.Textbox( + label="Prompt", + placeholder="Make it look like Lego world", + lines=2, + ) + with gr.Row(): + v2v_seed = gr.Number(label="Seed (optional)", precision=0) + v2v_enhance = gr.Checkbox(label="Enhance Prompt", value=True) + v2v_btn = gr.Button("Transform Video", variant="primary") + with gr.Column(): + v2v_output = gr.Video(label="Transformed Video") + v2v_status = gr.Textbox(label="Status", interactive=False) + + v2v_btn.click( + fn=lambda *args: asyncio.run(process_video_v2v(*args)), + inputs=[api_key, v2v_prompt, v2v_input, v2v_seed, v2v_enhance], + outputs=[v2v_output, v2v_status], + ) + + # ================================================================ + # Video Restyle Tab (NEW - with reference image support) + # ================================================================ + with gr.TabItem("Video Restyle (NEW)"): + gr.Markdown( + """ + ### Video Restyle with Prompt OR Reference Image + + This model supports two modes: + - **Text Prompt**: Describe the style you want + - **Reference Image**: Upload an image to use as style reference + """ + ) + + with gr.Row(): + with gr.Column(): + restyle_input = gr.Video(label="Input Video") + restyle_mode = gr.Checkbox( + label="Use Reference Image (instead of text prompt)", + value=False, + ) + restyle_prompt = gr.Textbox( + label="Prompt", + placeholder="Make it look like anime", + lines=2, + visible=True, + ) + restyle_ref_image = gr.Image( + label="Reference Image", + type="filepath", + visible=False, + ) + with gr.Row(): + restyle_seed = gr.Number(label="Seed (optional)", precision=0) + restyle_enhance = gr.Checkbox( + label="Enhance Prompt", + value=True, + visible=True, + ) + restyle_btn = gr.Button("Restyle Video", variant="primary") + with gr.Column(): + restyle_output = gr.Video(label="Restyled Video") + restyle_status = gr.Textbox(label="Status", interactive=False) + + # Toggle visibility based on mode + def toggle_mode(use_ref): + return ( + gr.update(visible=not use_ref), # prompt + 
gr.update(visible=use_ref), # ref image + gr.update(visible=not use_ref), # enhance + ) + + restyle_mode.change( + fn=toggle_mode, + inputs=[restyle_mode], + outputs=[restyle_prompt, restyle_ref_image, restyle_enhance], + ) + + restyle_btn.click( + fn=lambda *args: asyncio.run(process_video_restyle(*args)), + inputs=[ + api_key, + restyle_input, + restyle_mode, + restyle_prompt, + restyle_ref_image, + restyle_seed, + restyle_enhance, + ], + outputs=[restyle_output, restyle_status], + ) + + # ================================================================ + # Tokens Tab + # ================================================================ + with gr.TabItem("Tokens"): + gr.Markdown( + """ + ### Create Client Token + + Create a short-lived token for client-side use. + These tokens are meant for temporary access and expire automatically. + """ + ) + + with gr.Row(): + with gr.Column(): + token_btn = gr.Button("Create Token", variant="primary") + with gr.Column(): + token_output = gr.Textbox( + label="Result", + lines=5, + interactive=False, + ) + + token_btn.click( + fn=lambda key: asyncio.run(create_token(key)), + inputs=[api_key], + outputs=[token_output], + ) + + gr.Markdown( + """ + --- + **Note**: This UI uses the Decart Python SDK. + For realtime/WebRTC features, use the example scripts in `examples/`. + """ + ) + + return demo + + +if __name__ == "__main__": + demo = create_ui() + demo.launch( + server_name="127.0.0.1", # localhost only + server_port=7860, + share=False, + ) diff --git a/tests/test_models.py b/tests/test_models.py index 702145f..0d0e0ef 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,6 +17,14 @@ def test_realtime_models() -> None: assert model.height == 704 assert model.url_path == "/v1/stream" + # avatar-live model + model = models.realtime("avatar-live") + assert model.name == "avatar-live" + assert model.fps == 25 + assert model.width == 1280 + assert model.height == 720 + assert model.url_path == "/v1/avatar-live/stream" + def test_video_models() -> None: model = models.video("lucy-pro-t2v") @@ -26,6 +34,11 @@ def test_video_models() -> None: model = models.video("lucy-pro-v2v") assert model.name == "lucy-pro-v2v" + # lucy-restyle-v2v model + model = models.video("lucy-restyle-v2v") + assert model.name == "lucy-restyle-v2v" + assert model.url_path == "/v1/generate/lucy-restyle-v2v" + def test_image_models() -> None: model = models.image("lucy-pro-t2i") diff --git a/tests/test_queue.py b/tests/test_queue.py index ad18fc5..3d287ee 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -261,3 +261,101 @@ async def test_queue_includes_user_agent_header() -> None: assert "User-Agent" in headers assert headers["User-Agent"].startswith("decart-python-sdk/") + + +# Tests for lucy-restyle-v2v with reference_image + + +@pytest.mark.asyncio +async def test_queue_restyle_with_prompt() -> None: + """Test lucy-restyle-v2v submission with text prompt.""" + client = DecartClient(api_key="test-key") + + with patch("decart.queue.client.submit_job") as mock_submit: + mock_submit.return_value = MagicMock(job_id="job-789", status="pending") + + job = await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "prompt": "Make it look like anime", + "data": b"fake video data", + "enhance_prompt": True, + } + ) + + assert job.job_id == "job-789" + assert job.status == "pending" + mock_submit.assert_called_once() + + +@pytest.mark.asyncio +async def test_queue_restyle_with_reference_image() -> None: + """Test lucy-restyle-v2v submission with 
reference image.""" + client = DecartClient(api_key="test-key") + + with patch("decart.queue.client.submit_job") as mock_submit: + mock_submit.return_value = MagicMock(job_id="job-890", status="pending") + + job = await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "reference_image": b"fake image data", + "data": b"fake video data", + } + ) + + assert job.job_id == "job-890" + assert job.status == "pending" + mock_submit.assert_called_once() + + +@pytest.mark.asyncio +async def test_queue_restyle_rejects_both_prompt_and_reference_image() -> None: + """Test that lucy-restyle-v2v rejects both prompt and reference_image.""" + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "prompt": "Make it anime", + "reference_image": b"fake image data", + "data": b"fake video data", + } + ) + + assert "either 'prompt' or 'reference_image'" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_queue_restyle_rejects_neither_prompt_nor_reference_image() -> None: + """Test that lucy-restyle-v2v rejects when neither prompt nor reference_image provided.""" + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "data": b"fake video data", + } + ) + + assert "either 'prompt' or 'reference_image'" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_queue_restyle_rejects_enhance_prompt_with_reference_image() -> None: + """Test that enhance_prompt is only valid with text prompt, not reference_image.""" + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "reference_image": b"fake image data", + "data": b"fake video data", + "enhance_prompt": True, + } + ) + + assert "enhance_prompt" in str(exc_info.value).lower() diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index fee649d..15882ff 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -278,3 +278,319 @@ async def set_event(): assert "Server rejected prompt" in str(exc_info.value) mock_manager.unregister_prompt_wait.assert_called_with("New prompt") + + +# Tests for avatar-live model + + +def test_avatar_live_model_available(): + """Test that avatar-live model is available""" + model = models.realtime("avatar-live") + assert model.name == "avatar-live" + assert model.fps == 25 + assert model.width == 1280 + assert model.height == 720 + assert model.url_path == "/v1/avatar-live/stream" + + +@pytest.mark.asyncio +async def test_avatar_live_connect_with_avatar_image(): + """Test avatar-live connection with avatar image option""" + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.is_connected = MagicMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"fake image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = 
mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions, AvatarOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + avatar=AvatarOptions(avatar_image=b"fake image bytes"), + ), + ) + + assert realtime_client is not None + assert realtime_client._is_avatar_live is True + mock_file_input.assert_called_once() + # Verify avatar_image_base64 was passed to connect + mock_manager.connect.assert_called_once() + call_kwargs = mock_manager.connect.call_args[1] + assert "avatar_image_base64" in call_kwargs + assert call_kwargs["avatar_image_base64"] is not None + + +@pytest.mark.asyncio +async def test_avatar_live_set_image(): + """Test set_image method for avatar-live""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + image_set_event = asyncio.Event() + image_set_result = {"success": True, "error": None} + + mock_manager.register_image_set_wait = MagicMock( + return_value=(image_set_event, image_set_result) + ) + mock_manager.unregister_image_set_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"new image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + ), + ) + + async def set_event(): + await asyncio.sleep(0.01) + image_set_event.set() + + asyncio.create_task(set_event()) + await realtime_client.set_image(b"new avatar image") + + mock_manager.send_message.assert_called() + call_args = mock_manager.send_message.call_args[0][0] + assert call_args.type == "set_image" + assert call_args.image_data is not None + mock_manager.unregister_image_set_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_image_only_for_avatar_live(): + """Test that set_image raises error for non-avatar-live models""" + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + from decart.errors import InvalidInputError + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), # Not 
avatar-live + on_remote_stream=lambda t: None, + ), + ) + + with pytest.raises(InvalidInputError) as exc_info: + await realtime_client.set_image(b"test image") + + assert "avatar-live" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_avatar_live_set_image_timeout(): + """Test set_image raises on timeout""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + image_set_event = asyncio.Event() + image_set_result = {"success": False, "error": None} + + mock_manager.register_image_set_wait = MagicMock( + return_value=(image_set_event, image_set_result) + ) + mock_manager.unregister_image_set_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + from decart.errors import DecartSDKError + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + ), + ) + + with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError): + with pytest.raises(DecartSDKError) as exc_info: + await realtime_client.set_image(b"test image") + + assert "timed out" in str(exc_info.value).lower() + mock_manager.unregister_image_set_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_avatar_live_set_image_server_error(): + """Test set_image raises on server error""" + import asyncio + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + image_set_event = asyncio.Event() + image_set_result = {"success": False, "error": "Invalid image format"} + + mock_manager.register_image_set_wait = MagicMock( + return_value=(image_set_event, image_set_result) + ) + mock_manager.unregister_image_set_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"image data", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + from decart.errors import DecartSDKError + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("avatar-live"), + on_remote_stream=lambda t: None, + ), + ) + + async def set_event(): + await asyncio.sleep(0.01) + image_set_event.set() + + asyncio.create_task(set_event()) + + with 
pytest.raises(DecartSDKError) as exc_info: + await realtime_client.set_image(b"test image") + + assert "Invalid image format" in str(exc_info.value) + mock_manager.unregister_image_set_wait.assert_called_once() + + +@pytest.mark.asyncio +async def test_connect_with_initial_prompt(): + """Test connection with initial_prompt option""" + + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.is_connected = MagicMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions, InitialPromptOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + initial_prompt=InitialPromptOptions(text="Test prompt", enhance=False), + ), + ) + + assert realtime_client is not None + mock_manager.connect.assert_called_once() + call_kwargs = mock_manager.connect.call_args[1] + assert "initial_prompt" in call_kwargs + assert call_kwargs["initial_prompt"] == {"text": "Test prompt", "enhance": False}