Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 73 additions & 35 deletions quickstarts/Get_started_LiveAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
asyncio.TaskGroup = taskgroup.TaskGroup
asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup

# ============================================================================
# Configuration
# ============================================================================

FORMAT = pyaudio.paInt16
CHANNELS = 1
SEND_SAMPLE_RATE = 16000
Expand All @@ -84,19 +88,28 @@
pya = pyaudio.PyAudio()


class AudioLoop:
# ============================================================================
# Main Application Class
# ============================================================================

class AudioVideoLoop:
def __init__(self, video_mode=DEFAULT_MODE):
self.video_mode = video_mode

self.audio_in_queue = None
self.out_queue = None

self.session = None
self.audio_stream = None

self.send_text_task = None
self.receive_audio_task = None
self.play_audio_task = None

# ========================================================================
# Text Input
# ========================================================================

async def send_text(self):
while True:
text = await asyncio.to_thread(
Expand All @@ -107,8 +120,12 @@ async def send_text(self):
break
await self.session.send(input=text or ".", end_of_turn=True)

def _get_frame(self, cap):
# Read the frameq
# ========================================================================
# Video Capture
# ========================================================================

def _capture_frame(self, cap):
# Read the frame from camera
ret, frame = cap.read()
# Check if the frame was read successfully
if not ret:
Expand All @@ -128,26 +145,27 @@ def _get_frame(self, cap):
image_bytes = image_io.read()
return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()}

async def get_frames(self):
async def stream_camera_frames(self):
# This takes about a second, and will block the whole program
# causing the audio pipeline to overflow if you don't to_thread it.
cap = await asyncio.to_thread(
cv2.VideoCapture, 0
) # 0 represents the default camera

while True:
frame = await asyncio.to_thread(self._get_frame, cap)
if frame is None:
break

await asyncio.sleep(1.0)
try:
while True:
frame = await asyncio.to_thread(self._capture_frame, cap)
if frame is None:
break

await self.out_queue.put(frame)
await asyncio.sleep(1.0)

# Release the VideoCapture object
cap.release()
await self.out_queue.put(frame)
finally:
# Release the VideoCapture object
cap.release()

def _get_screen(self):
def _capture_screen(self):
sct = mss.mss()
monitor = sct.monitors[0]

Expand All @@ -164,22 +182,29 @@ def _get_screen(self):
image_bytes = image_io.read()
return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()}

async def get_screen(self):

async def stream_screen(self):
while True:
frame = await asyncio.to_thread(self._get_screen)
frame = await asyncio.to_thread(self._capture_screen)
if frame is None:
break

await asyncio.sleep(1.0)

await self.out_queue.put(frame)

# ========================================================================
# Real-time Media Stream
# ========================================================================

async def send_realtime(self):
while True:
msg = await self.out_queue.get()
await self.session.send(input=msg)

# ========================================================================
# Audio Input/Output
# ========================================================================

async def listen_audio(self):
mic_info = pya.get_default_input_device_info()
self.audio_stream = await asyncio.to_thread(
Expand All @@ -191,16 +216,20 @@ async def listen_audio(self):
input_device_index=mic_info["index"],
frames_per_buffer=CHUNK_SIZE,
)
if __debug__:
kwargs = {"exception_on_overflow": False}
else:
kwargs = {}
while True:
data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs)
await self.out_queue.put({"data": data, "mime_type": "audio/pcm"})
try:
if __debug__:
kwargs = {"exception_on_overflow": False}
else:
kwargs = {}
while True:
data = await asyncio.to_thread(self.audio_stream.read, CHUNK_SIZE, **kwargs)
await self.out_queue.put({"data": data, "mime_type": "audio/pcm"})
finally:
# Ensure audio stream is closed on exit
self.audio_stream.close()

async def receive_audio(self):
"Background task to reads from the websocket and write pcm chunks to the output queue"
"""Background task to read from the websocket and queue audio chunks for playback."""
while True:
turn = self.session.receive()
async for response in turn:
Expand All @@ -225,11 +254,20 @@ async def play_audio(self):
rate=RECEIVE_SAMPLE_RATE,
output=True,
)
while True:
bytestream = await self.audio_in_queue.get()
await asyncio.to_thread(stream.write, bytestream)
try:
while True:
bytestream = await self.audio_in_queue.get()
await asyncio.to_thread(stream.write, bytestream)
finally:
# Ensure output stream is closed on exit
stream.close()

# ========================================================================
# Main Event Loop
# ========================================================================

async def run(self):
"""Main application loop managing all async tasks."""
try:
async with (
client.aio.live.connect(model=MODEL, config=CONFIG) as session,
Expand All @@ -244,9 +282,9 @@ async def run(self):
tg.create_task(self.send_realtime())
tg.create_task(self.listen_audio())
if self.video_mode == "camera":
tg.create_task(self.get_frames())
tg.create_task(self.stream_camera_frames())
elif self.video_mode == "screen":
tg.create_task(self.get_screen())
tg.create_task(self.stream_screen())

tg.create_task(self.receive_audio())
tg.create_task(self.play_audio())
Expand All @@ -256,9 +294,9 @@ async def run(self):

except asyncio.CancelledError:
pass
except ExceptionGroup as EG:
self.audio_stream.close()
traceback.print_exception(EG)
except ExceptionGroup as eg:
traceback.print_exc()



if __name__ == "__main__":
Expand All @@ -271,5 +309,5 @@ async def run(self):
choices=["camera", "screen", "none"],
)
args = parser.parse_args()
main = AudioLoop(video_mode=args.mode)
asyncio.run(main.run())
app = AudioVideoLoop(video_mode=args.mode)
asyncio.run(app.run())