-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Refactor LiveAPI sample for cleaner async flow and robust audio/video handling #1073
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -51,64 +51,153 @@ | |||||
| import os | ||||||
| import sys | ||||||
| import traceback | ||||||
| import argparse | ||||||
|
|
||||||
| import cv2 | ||||||
| import pyaudio | ||||||
| import PIL.Image | ||||||
| import mss | ||||||
|
|
||||||
| import argparse | ||||||
|
|
||||||
| from google import genai | ||||||
| from google.genai import types | ||||||
|
|
||||||
| if sys.version_info < (3, 11, 0): | ||||||
| import taskgroup, exceptiongroup | ||||||
|
|
||||||
| asyncio.TaskGroup = taskgroup.TaskGroup | ||||||
| asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup | ||||||
|
|
||||||
| # --- Audio Configuration --- | ||||||
| FORMAT = pyaudio.paInt16 | ||||||
| CHANNELS = 1 | ||||||
| SEND_SAMPLE_RATE = 16000 | ||||||
| RECEIVE_SAMPLE_RATE = 24000 | ||||||
| CHUNK_SIZE = 1024 | ||||||
|
|
||||||
| MODEL = "gemini-2.5-flash-native-audio-preview-12-2025" | ||||||
|
|
||||||
| # --- Model Configuration --- | ||||||
| MODEL = "models/gemini-2.5-flash-native-audio-preview-12-2025" | ||||||
| DEFAULT_MODE = "camera" | ||||||
|
|
||||||
| client = genai.Client(http_options={"api_version": "v1beta"}) | ||||||
|
|
||||||
| CONFIG = {"response_modalities": ["AUDIO"]} | ||||||
| client = genai.Client( | ||||||
| api_key = os.environ.get("GEMINI_API_KEY"), | ||||||
| http_options={"api_version": "v1beta"}, | ||||||
| ) | ||||||
|
|
||||||
| # Live session configuration | ||||||
| # Trigger tokens sent so that model does not hallucinate in long conversations | ||||||
| # Sliding window to retain the context within the context window limit | ||||||
| CONFIG = types.LiveConnectConfig( | ||||||
| response_modalities=["AUDIO"], | ||||||
| speech_config=types.SpeechConfig( | ||||||
| voice_config=types.VoiceConfig( | ||||||
| prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name = "Zephyr") | ||||||
| ) | ||||||
| ), | ||||||
| context_window_compression=types.ContextWindowCompressionConfig( | ||||||
| trigger_tokens = 25600, | ||||||
| sliding_window = types.SlidingWindow(target_tokens=12800), | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
| pya = pyaudio.PyAudio() | ||||||
|
|
||||||
|
|
||||||
| class AudioLoop: | ||||||
| class AudioVideoLoop: | ||||||
| def __init__(self, video_mode=DEFAULT_MODE): | ||||||
| self.video_mode = video_mode | ||||||
|
|
||||||
| self.audio_in_queue = None | ||||||
| self.out_queue = None | ||||||
| self.audio_in_queue = asyncio.Queue() | ||||||
| self.out_queue = asyncio.Queue(maxsize = 5) # Limit size to avoid excess memory use | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor style fix: according to PEP 8, there should be no space around the
Suggested change
References
|
||||||
|
|
||||||
| self.session = None | ||||||
| self.audio_stream = None | ||||||
|
|
||||||
| self.send_text_task = None | ||||||
| self.receive_audio_task = None | ||||||
| self.play_audio_task = None | ||||||
| # --- Audio Handling --- | ||||||
|
|
||||||
| async def send_text(self): | ||||||
| while True: | ||||||
| text = await asyncio.to_thread( | ||||||
| input, | ||||||
| "message > ", | ||||||
| ) | ||||||
| if text.lower() == "q": | ||||||
| break | ||||||
| await self.session.send(input=text or ".", end_of_turn=True) | ||||||
|
|
||||||
| def _get_frame(self, cap): | ||||||
| # Read the frameq | ||||||
| async def listen_audio(self): | ||||||
| mic_info = pya.get_default_input_device_info() | ||||||
| self.audio_stream = await asyncio.to_thread( | ||||||
| pya.open, | ||||||
| format=FORMAT, | ||||||
| channels=CHANNELS, | ||||||
| rate=SEND_SAMPLE_RATE, | ||||||
| input=True, | ||||||
| input_device_index=mic_info["index"], | ||||||
| frames_per_buffer=CHUNK_SIZE, | ||||||
| ) | ||||||
| if __debug__: | ||||||
| kwargs = {"exception_on_overflow": False} | ||||||
| else: | ||||||
| kwargs = {} | ||||||
|
|
||||||
| try: | ||||||
| while True: | ||||||
| data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs) | ||||||
| payload = { | ||||||
| "data": data, | ||||||
| "mime_type": "audio/pcm" | ||||||
| } | ||||||
| # To reduce latency, instead of waiting to push into the queue, we pop the oldest item in the queue if it's full | ||||||
| # This helps to keep the audio stream real-time | ||||||
| try: | ||||||
| self.out_queue.put_nowait(payload) | ||||||
| except asyncio.QueueFull: | ||||||
| _ = self.out_queue.get_nowait() | ||||||
| self.out_queue.put_nowait(payload) | ||||||
|
|
||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
| finally: | ||||||
| if self.audio_stream: | ||||||
| self.audio_stream.stop_stream() | ||||||
| self.audio_stream.close() | ||||||
|
|
||||||
| async def play_audio(self): | ||||||
| stream = await asyncio.to_thread( | ||||||
| pya.open, | ||||||
| format=FORMAT, | ||||||
| channels=CHANNELS, | ||||||
| rate=RECEIVE_SAMPLE_RATE, | ||||||
| output=True, | ||||||
| ) | ||||||
| try: | ||||||
| while True: | ||||||
| bytestream = await self.audio_in_queue.get() | ||||||
| await asyncio.to_thread(stream.write, bytestream) | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
| finally: | ||||||
| if stream: | ||||||
| stream.stop_stream() | ||||||
| stream.close() | ||||||
|
|
||||||
| async def receive_audio(self): | ||||||
| """Read from the websocket and write PCM chunks to the output queue.""" | ||||||
| try: | ||||||
| while True: | ||||||
| turn = self.session.receive() | ||||||
| async for response in turn: | ||||||
| if data := response.data: | ||||||
| self.audio_in_queue.put_nowait(data) | ||||||
| continue | ||||||
| if text := response.text: | ||||||
| print(text, end="") | ||||||
|
|
||||||
| # If you interrupt the model, it sends a turn_complete. | ||||||
| # For interruptions to work, we need to stop playback. | ||||||
| # So empty out the audio queue because it may have loaded | ||||||
| # much more audio than has played yet. | ||||||
| while not self.audio_in_queue.empty(): | ||||||
| self.audio_in_queue.get_nowait() | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
|
|
||||||
| # --- Video Handling --- | ||||||
|
|
||||||
| def _capture_frame(self, cap): | ||||||
| """Capture frame from camera and convert to base64 JPEG.""" | ||||||
| # Read the frame | ||||||
| ret, frame = cap.read() | ||||||
| # Check if the frame was read successfully | ||||||
| if not ret: | ||||||
|
|
@@ -117,7 +206,7 @@ def _get_frame(self, cap): | |||||
| # OpenCV captures in BGR but PIL expects RGB format | ||||||
| # This prevents the blue tint in the video feed | ||||||
| frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | ||||||
| img = PIL.Image.fromarray(frame_rgb) # Now using RGB frame | ||||||
| img = PIL.Image.fromarray(frame_rgb) | ||||||
| img.thumbnail([1024, 1024]) | ||||||
|
|
||||||
| image_io = io.BytesIO() | ||||||
|
|
@@ -128,125 +217,97 @@ def _get_frame(self, cap): | |||||
| image_bytes = image_io.read() | ||||||
| return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} | ||||||
|
|
||||||
| async def get_frames(self): | ||||||
| # This takes about a second, and will block the whole program | ||||||
| # causing the audio pipeline to overflow if you don't to_thread it. | ||||||
| async def capture_frames(self): | ||||||
| cap = await asyncio.to_thread( | ||||||
| cv2.VideoCapture, 0 | ||||||
| ) # 0 represents the default camera | ||||||
|
|
||||||
| while True: | ||||||
| frame = await asyncio.to_thread(self._get_frame, cap) | ||||||
| if frame is None: | ||||||
| break | ||||||
|
|
||||||
| await asyncio.sleep(1.0) | ||||||
|
|
||||||
| await self.out_queue.put(frame) | ||||||
| try: | ||||||
| while True: | ||||||
| frame = await asyncio.to_thread(self._capture_frame, cap) | ||||||
| if frame is None: | ||||||
| break | ||||||
|
|
||||||
| # Release the VideoCapture object | ||||||
| cap.release() | ||||||
| await asyncio.sleep(1.0) | ||||||
| await self.out_queue.put(frame) | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
| finally: | ||||||
| cap.release() | ||||||
|
|
||||||
| def _get_screen(self): | ||||||
| def _capture_screen(self): | ||||||
| sct = mss.mss() | ||||||
| monitor = sct.monitors[0] | ||||||
|
|
||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are a few unnecessary blank lines in the code (here and on line 306) that could be removed to improve readability and compactness. References
|
||||||
| i = sct.grab(monitor) | ||||||
|
|
||||||
| mime_type = "image/jpeg" | ||||||
| image_bytes = mss.tools.to_png(i.rgb, i.size) | ||||||
| img = PIL.Image.open(io.BytesIO(image_bytes)) | ||||||
|
|
||||||
| img = PIL.Image.frombytes("RGB", i.size, i.rgb) | ||||||
|
|
||||||
| image_io = io.BytesIO() | ||||||
| img.save(image_io, format="jpeg") | ||||||
| image_io.seek(0) | ||||||
|
|
||||||
| mime_type = "image/jpeg" | ||||||
| image_bytes = image_io.read() | ||||||
| return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} | ||||||
|
|
||||||
| async def get_screen(self): | ||||||
| async def capture_screen(self): | ||||||
| try: | ||||||
| while True: | ||||||
| frame = await asyncio.to_thread(self._capture_screen) | ||||||
| if frame is None: | ||||||
| break | ||||||
|
|
||||||
| while True: | ||||||
| frame = await asyncio.to_thread(self._get_screen) | ||||||
| if frame is None: | ||||||
| break | ||||||
| await asyncio.sleep(1.0) | ||||||
| await self.out_queue.put(frame) | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
|
|
||||||
| await asyncio.sleep(1.0) | ||||||
| # --- Text & Main Loop --- | ||||||
|
|
||||||
| await self.out_queue.put(frame) | ||||||
| async def send_text(self): | ||||||
| try: | ||||||
| while True: | ||||||
| text = await asyncio.to_thread( | ||||||
| input, | ||||||
| "message > ", | ||||||
| ) | ||||||
| if text.lower() == "q": | ||||||
| print("👋 Exiting on user request...") | ||||||
| break | ||||||
| await self.session.send(input=text or ".", end_of_turn=True) | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
|
|
||||||
| async def send_realtime(self): | ||||||
| while True: | ||||||
| msg = await self.out_queue.get() | ||||||
| await self.session.send(input=msg) | ||||||
|
|
||||||
| async def listen_audio(self): | ||||||
| mic_info = pya.get_default_input_device_info() | ||||||
| self.audio_stream = await asyncio.to_thread( | ||||||
| pya.open, | ||||||
| format=FORMAT, | ||||||
| channels=CHANNELS, | ||||||
| rate=SEND_SAMPLE_RATE, | ||||||
| input=True, | ||||||
| input_device_index=mic_info["index"], | ||||||
| frames_per_buffer=CHUNK_SIZE, | ||||||
| ) | ||||||
| if __debug__: | ||||||
| kwargs = {"exception_on_overflow": False} | ||||||
| else: | ||||||
| kwargs = {} | ||||||
| while True: | ||||||
| data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs) | ||||||
| await self.out_queue.put({"data": data, "mime_type": "audio/pcm"}) | ||||||
|
|
||||||
| async def receive_audio(self): | ||||||
| "Background task to reads from the websocket and write pcm chunks to the output queue" | ||||||
| while True: | ||||||
| turn = self.session.receive() | ||||||
| async for response in turn: | ||||||
| if data := response.data: | ||||||
| self.audio_in_queue.put_nowait(data) | ||||||
| continue | ||||||
| if text := response.text: | ||||||
| print(text, end="") | ||||||
|
|
||||||
| # If you interrupt the model, it sends a turn_complete. | ||||||
| # For interruptions to work, we need to stop playback. | ||||||
| # So empty out the audio queue because it may have loaded | ||||||
| # much more audio than has played yet. | ||||||
| while not self.audio_in_queue.empty(): | ||||||
| self.audio_in_queue.get_nowait() | ||||||
|
|
||||||
| async def play_audio(self): | ||||||
| stream = await asyncio.to_thread( | ||||||
| pya.open, | ||||||
| format=FORMAT, | ||||||
| channels=CHANNELS, | ||||||
| rate=RECEIVE_SAMPLE_RATE, | ||||||
| output=True, | ||||||
| ) | ||||||
| while True: | ||||||
| bytestream = await self.audio_in_queue.get() | ||||||
| await asyncio.to_thread(stream.write, bytestream) | ||||||
| try: | ||||||
| while True: | ||||||
| msg = await self.out_queue.get() | ||||||
| await self.session.send(input=msg) | ||||||
| except asyncio.CancelledError: | ||||||
| pass | ||||||
|
|
||||||
| async def run(self): | ||||||
| """Run all tasks to handle audio/video/text interaction""" | ||||||
| try: | ||||||
| async with ( | ||||||
| client.aio.live.connect(model=MODEL, config=CONFIG) as session, | ||||||
| asyncio.TaskGroup() as tg, | ||||||
| ): | ||||||
| self.session = session | ||||||
|
|
||||||
| # Re-initialize queue for fresh session | ||||||
| self.audio_in_queue = asyncio.Queue() | ||||||
| self.out_queue = asyncio.Queue(maxsize=5) | ||||||
|
|
||||||
| send_text_task = tg.create_task(self.send_text()) | ||||||
| tg.create_task(self.send_realtime()) | ||||||
| tg.create_task(self.listen_audio()) | ||||||
|
|
||||||
| if self.video_mode == "camera": | ||||||
| tg.create_task(self.get_frames()) | ||||||
| tg.create_task(self.capture_frames()) | ||||||
| elif self.video_mode == "screen": | ||||||
| tg.create_task(self.get_screen()) | ||||||
| tg.create_task(self.capture_screen()) | ||||||
|
|
||||||
| tg.create_task(self.receive_audio()) | ||||||
| tg.create_task(self.play_audio()) | ||||||
|
|
@@ -271,5 +332,5 @@ async def run(self): | |||||
| choices=["camera", "screen", "none"], | ||||||
| ) | ||||||
| args = parser.parse_args() | ||||||
| main = AudioLoop(video_mode=args.mode) | ||||||
| main = AudioVideoLoop(video_mode=args.mode) | ||||||
| asyncio.run(main.run()) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The repository's style guide defers to the Google Python Style Guide, which specifies that there should be no spaces around the
=sign for keyword arguments. This convention is also enforced by tools likepyink. I've noticed this pattern in a few other places in the file (e.g., lines 94, 98, 99, 111). It would be great to fix them for consistency.References
=sign for keyword arguments or default parameter values. (link)