diff --git a/quickstarts/Get_started_LiveAPI.py b/quickstarts/Get_started_LiveAPI.py index ce6f7952d..91f274c37 100755 --- a/quickstarts/Get_started_LiveAPI.py +++ b/quickstarts/Get_started_LiveAPI.py @@ -67,6 +67,10 @@ asyncio.TaskGroup = taskgroup.TaskGroup asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup +# ============================================================================ +# Configuration +# ============================================================================ + FORMAT = pyaudio.paInt16 CHANNELS = 1 SEND_SAMPLE_RATE = 16000 @@ -84,7 +88,11 @@ pya = pyaudio.PyAudio() -class AudioLoop: +# ============================================================================ +# Main Application Class +# ============================================================================ + +class AudioVideoLoop: def __init__(self, video_mode=DEFAULT_MODE): self.video_mode = video_mode @@ -92,11 +100,16 @@ def __init__(self, video_mode=DEFAULT_MODE): self.out_queue = None self.session = None + self.audio_stream = None self.send_text_task = None self.receive_audio_task = None self.play_audio_task = None + # ======================================================================== + # Text Input + # ======================================================================== + async def send_text(self): while True: text = await asyncio.to_thread( @@ -107,8 +120,12 @@ async def send_text(self): break await self.session.send(input=text or ".", end_of_turn=True) - def _get_frame(self, cap): - # Read the frameq + # ======================================================================== + # Video Capture + # ======================================================================== + + def _capture_frame(self, cap): + # Read the frame from camera ret, frame = cap.read() # Check if the frame was read successfully if not ret: @@ -128,26 +145,27 @@ def _get_frame(self, cap): image_bytes = image_io.read() return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} - async def get_frames(self): + async def stream_camera_frames(self): # This takes about a second, and will block the whole program # causing the audio pipeline to overflow if you don't to_thread it. cap = await asyncio.to_thread( cv2.VideoCapture, 0 ) # 0 represents the default camera - while True: - frame = await asyncio.to_thread(self._get_frame, cap) - if frame is None: - break - - await asyncio.sleep(1.0) + try: + while True: + frame = await asyncio.to_thread(self._capture_frame, cap) + if frame is None: + break - await self.out_queue.put(frame) + await asyncio.sleep(1.0) - # Release the VideoCapture object - cap.release() + await self.out_queue.put(frame) + finally: + # Release the VideoCapture object + cap.release() - def _get_screen(self): + def _capture_screen(self): sct = mss.mss() monitor = sct.monitors[0] @@ -164,10 +182,9 @@ def _get_screen(self): image_bytes = image_io.read() return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} - async def get_screen(self): - + async def stream_screen(self): while True: - frame = await asyncio.to_thread(self._get_screen) + frame = await asyncio.to_thread(self._capture_screen) if frame is None: break @@ -175,11 +192,19 @@ async def get_screen(self): await self.out_queue.put(frame) + # ======================================================================== + # Real-time Media Stream + # ======================================================================== + async def send_realtime(self): while True: msg = await self.out_queue.get() await self.session.send(input=msg) + # ======================================================================== + # Audio Input/Output + # ======================================================================== + async def listen_audio(self): mic_info = pya.get_default_input_device_info() self.audio_stream = await asyncio.to_thread( @@ -191,16 +216,20 @@ async def listen_audio(self): input_device_index=mic_info["index"], frames_per_buffer=CHUNK_SIZE, ) - if __debug__: - kwargs = {"exception_on_overflow": False} - else: - kwargs = {} - while True: - data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs) - await self.out_queue.put({"data": data, "mime_type": "audio/pcm"}) + try: + if __debug__: + kwargs = {"exception_on_overflow": False} + else: + kwargs = {} + while True: + data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs) + await self.out_queue.put({"data": data, "mime_type": "audio/pcm"}) + finally: + # Ensure audio stream is closed on exit + self.audio_stream.close() async def receive_audio(self): - "Background task to reads from the websocket and write pcm chunks to the output queue" + """Background task to read from the websocket and queue audio chunks for playback.""" while True: turn = self.session.receive() async for response in turn: @@ -225,11 +254,20 @@ async def play_audio(self): rate=RECEIVE_SAMPLE_RATE, output=True, ) - while True: - bytestream = await self.audio_in_queue.get() - await asyncio.to_thread(stream.write, bytestream) + try: + while True: + bytestream = await self.audio_in_queue.get() + await asyncio.to_thread(stream.write, bytestream) + finally: + # Ensure output stream is closed on exit + stream.close() + + # ======================================================================== + # Main Event Loop + # ======================================================================== async def run(self): + """Main application loop managing all async tasks.""" try: async with ( client.aio.live.connect(model=MODEL, config=CONFIG) as session, @@ -244,9 +282,9 @@ async def run(self): tg.create_task(self.send_realtime()) tg.create_task(self.listen_audio()) if self.video_mode == "camera": - tg.create_task(self.get_frames()) + tg.create_task(self.stream_camera_frames()) elif self.video_mode == "screen": - tg.create_task(self.get_screen()) + tg.create_task(self.stream_screen()) tg.create_task(self.receive_audio()) tg.create_task(self.play_audio()) @@ -256,9 +294,9 @@ async def run(self): except asyncio.CancelledError: pass - except ExceptionGroup as EG: - self.audio_stream.close() - traceback.print_exception(EG) + except ExceptionGroup as eg: + traceback.print_exc() + if __name__ == "__main__": @@ -271,5 +309,5 @@ async def run(self): choices=["camera", "screen", "none"], ) args = parser.parse_args() - main = AudioLoop(video_mode=args.mode) - asyncio.run(main.run()) + app = AudioVideoLoop(video_mode=args.mode) + asyncio.run(app.run())