Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 176 additions & 14 deletions getstream/video/rtc/track_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,34 @@
import base64
import fractions
import io
import logging
import re
import wave
from enum import Enum

import av
import numpy as np
import re
from fractions import Fraction
from typing import (
Dict,
Any,
AsyncIterator,
Callable,
Optional,
Union,
Dict,
Iterator,
AsyncIterator,
Literal,
Optional,
Union,
)

import logging
import aiortc
import av
import numpy as np
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from numpy.typing import NDArray

from getstream.video.rtc.g711 import (
ALAW_DECODE_TABLE,
MULAW_DECODE_TABLE,
G711Encoding,
G711Mapping,
MULAW_DECODE_TABLE,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -535,8 +535,6 @@ def from_av_frame(cls, frame: "av.AudioFrame") -> "PcmData":
# Convert time_base from Fraction to float if present
time_base = None
if hasattr(frame, "time_base") and frame.time_base is not None:
from fractions import Fraction

if isinstance(frame.time_base, Fraction):
time_base = float(frame.time_base)
else:
Expand Down Expand Up @@ -810,7 +808,7 @@ def to_av_frame(self) -> "av.AudioFrame":
layout = "mono" if pcm_formatted.channels == 1 else "stereo"
frame = av.AudioFrame.from_ndarray(samples, format=av_format, layout=layout)
frame.sample_rate = pcm_formatted.sample_rate

frame.pts = pcm_formatted.pts
return frame

def g711_bytes(
Expand Down Expand Up @@ -1668,6 +1666,170 @@ def head(
participant=self.participant,
)

@property
def empty(self) -> bool:
return len(self.samples) == 0


class PyAVResampler:
"""
A stateful audio resampler.
It acts as a thin wrapper around `pyav.AudioResampler`, and it is intended to be
created once for the audio track and re-used.


Key differences from the stateless implementation:

- `pyav.AudioResampler` buffers samples internally, so the number of output samples doesn't always match the input.
- `PyAVResampler` keeps its own monotonic PTS clock, and it's meant to be used with a single audio stream only.
It ignores the PTS/DTS from `PcmData`.
PTS always starts from 0 for the first output.
- `pyav.AudioResampler` configures itself based on the first input frame. Feeding data in a different format
or sample rate will fail.
- `PyAVResampler` is not thread-safe.

The source PCMs must have the same sample rate, format, and number of channels.

Example:

>>> import numpy as np
>>> resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1)
>>> # Process 20ms chunks at 16kHz (320 samples each)
>>> samples = np.random.randint(-1000, 1000, 320, dtype=np.int16)
>>> pcm_16k = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1)
>>> pcm_48k = resampler.resample(pcm_16k) # Returns 912 samples at 48kHz without flushing
>>> len(pcm_48k.samples)
912
>>> flushed_pcm = resampler.flush()
>>> len(flushed_pcm.samples)
48
"""

def __init__(
self,
format: AudioFormatType,
sample_rate: int,
channels: int,
frame_size: int = 0,
):
"""
Initialize a stateful resampler with target audio parameters.

Args:
format: Target format ("s16" or "f32", also `AudioFormat.F32` or `AudioFormat.S16`).
sample_rate: Target sample rate (e.g., 48000, 16000, 8000)
channels: Target number of channels (1 for mono, 2 for stereo)
frame_size: how many samples per channel are produce in each output frame.
When set, the underlying resampler will buffer output the specified number of samples is accumulated,
and it will output frames of this exact size (except when ".resample(flush=True)").
Default - `0` (each frame can be of a variable size).
"""
if isinstance(format, str):
AudioFormat.validate(format)

self.format = AudioFormat(format)
self.sample_rate = sample_rate
self.channels = channels
self.frame_size = frame_size
# Determine PyAV format based on original format to preserve it
# f32 -> fltp (float32 planar), s16 -> s16p (int16 planar)
self._pyav_format = "fltp" if self.format == AudioFormat.F32 else "s16p"
self._pts = 0
self._set_pyav_resampler()

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(format={self.format.name.lower()!r}, "
f"sample_rate={self.sample_rate}, channels={self.channels}, frame_size={self.frame_size})"
)

def _set_pyav_resampler(self):
# Create PyAV resampler with format matching the original
self._pyav_resampler = av.AudioResampler(
format=self._pyav_format,
layout="mono" if self.channels == 1 else "stereo",
rate=self.sample_rate,
frame_size=self.frame_size,
)

def _pyav_resample(self, frame: av.AudioFrame | None) -> list[av.AudioFrame]:
if frame is not None and not frame.samples:
# pyav resampler fails if audioframe has no samples
return []
return self._pyav_resampler.resample(frame)

def resample(self, pcm: PcmData, flush: bool = False) -> PcmData:
"""
Resample using PyAV (libav) for high-quality resampling and downmixing.

Args:
pcm: Input PCM data to resample
flush: if True, get the remaining frames from underlying `av.AudioResampler` if there are any.
Default - `False`.

Returns:
New PcmData object with resampled audio, potentially empty if the frame size is set and larger
than the input PCM.
"""
# Create AudioFrame from PcmData
frame = pcm.to_av_frame()

# Convert each frame to PcmData using from_av_frame and concatenate them
# Start with an empty PcmData preserving the original format
result = PcmData(
sample_rate=self.sample_rate,
format=self.format,
channels=self.channels,
time_base=1 / self.sample_rate,
)

# Resample
# Keep the lock because resampler is stateful, and we want to keep PTS in order
resampled_frames = self._pyav_resample(frame)
if flush:
try:
resampled_frames.extend(self._pyav_resample(None))
finally:
# Reset the resampler because it cannot be used after it's flushed,
self._set_pyav_resampler()

for resampled_frame in resampled_frames:
self._pts += resampled_frame.samples
result = result.append(PcmData.from_av_frame(resampled_frame))

result.pts = self._pts - len(result.samples)
result.dts = result.pts
return result

def flush(self) -> PcmData:
"""
Flush the underlying `av.AudioResampler`

Returns:
New PcmData object with resampled audio, potentially empty.
"""
# Convert each frame to PcmData using from_av_frame and concatenate them
# Start with an empty PcmData preserving the original format
result = PcmData(
sample_rate=self.sample_rate,
format=self.format,
channels=self.channels,
)

try:
# Flush the resampler to get remaining buffered samples
resampled_frames = self._pyav_resample(None)
finally:
# Reset the resampler because it cannot be used after it's flushed,
self._set_pyav_resampler()

# Convert frames to PcmData and update the PTS clock
for resampled_frame in resampled_frames:
self._pts += resampled_frame.samples
result = result.append(PcmData.from_av_frame(resampled_frame))
result.pts = self._pts - len(result.samples)
return result


class Resampler:
"""
Expand Down Expand Up @@ -1867,7 +2029,7 @@ def _adjust_format(

def __repr__(self) -> str:
return (
f"Resampler(format={self.format!r}, "
f"{self.__class__.__name__}(format={self.format!r}, "
f"sample_rate={self.sample_rate}, channels={self.channels})"
)

Expand Down
Loading