From 472edaaa9993f46ac201a3b2eebc1ed85f900bae Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Sat, 20 Dec 2025 13:57:37 +0100 Subject: [PATCH] Added `PyAVResampler` - a stateful resampler implementation --- getstream/video/rtc/track_util.py | 190 ++++++++++++++-- tests/rtc/test_pcm_data.py | 362 +++++++++++++++++++++++++++++- 2 files changed, 535 insertions(+), 17 deletions(-) diff --git a/getstream/video/rtc/track_util.py b/getstream/video/rtc/track_util.py index 92b4f0c7..f5e34217 100644 --- a/getstream/video/rtc/track_util.py +++ b/getstream/video/rtc/track_util.py @@ -2,34 +2,34 @@ import base64 import fractions import io +import logging +import re import wave from enum import Enum - -import av -import numpy as np -import re +from fractions import Fraction from typing import ( - Dict, Any, + AsyncIterator, Callable, - Optional, - Union, + Dict, Iterator, - AsyncIterator, Literal, + Optional, + Union, ) -import logging import aiortc +import av +import numpy as np from aiortc import MediaStreamTrack from aiortc.mediastreams import MediaStreamError from numpy.typing import NDArray from getstream.video.rtc.g711 import ( ALAW_DECODE_TABLE, + MULAW_DECODE_TABLE, G711Encoding, G711Mapping, - MULAW_DECODE_TABLE, ) logger = logging.getLogger(__name__) @@ -535,8 +535,6 @@ def from_av_frame(cls, frame: "av.AudioFrame") -> "PcmData": # Convert time_base from Fraction to float if present time_base = None if hasattr(frame, "time_base") and frame.time_base is not None: - from fractions import Fraction - if isinstance(frame.time_base, Fraction): time_base = float(frame.time_base) else: @@ -810,7 +808,7 @@ def to_av_frame(self) -> "av.AudioFrame": layout = "mono" if pcm_formatted.channels == 1 else "stereo" frame = av.AudioFrame.from_ndarray(samples, format=av_format, layout=layout) frame.sample_rate = pcm_formatted.sample_rate - + frame.pts = pcm_formatted.pts return frame def g711_bytes( @@ -1668,6 +1666,170 @@ def head( participant=self.participant, ) + @property + def empty(self) -> bool: + return len(self.samples) == 0 + + +class PyAVResampler: + """ + A stateful audio resampler. + It acts as a thin wrapper around `pyav.AudioResampler`, and it is intended to be + created once for the audio track and re-used. + + + Key differences from the stateless implementation: + + - `pyav.AudioResampler` buffers samples internally, so the number of output samples doesn't always match the input. + - `PyAVResampler` keeps its own monotonic PTS clock, and it's meant to be used with a single audio stream only. + It ignores the PTS/DTS from `PcmData`. + PTS always starts from 0 for the first output. + - `pyav.AudioResampler` configures itself based on the first input frame. Feeding data in a different format + or sample rate will fail. + - `PyAVResampler` is not thread-safe. + + The source PCMs must have the same sample rate, format, and number of channels. + + Example: + + >>> import numpy as np + >>> resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + >>> # Process 20ms chunks at 16kHz (320 samples each) + >>> samples = np.random.randint(-1000, 1000, 320, dtype=np.int16) + >>> pcm_16k = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + >>> pcm_48k = resampler.resample(pcm_16k) # Returns 912 samples at 48kHz without flushing + >>> len(pcm_48k.samples) + 912 + >>> flushed_pcm = resampler.flush() + >>> len(flushed_pcm.samples) + 48 + """ + + def __init__( + self, + format: AudioFormatType, + sample_rate: int, + channels: int, + frame_size: int = 0, + ): + """ + Initialize a stateful resampler with target audio parameters. + + Args: + format: Target format ("s16" or "f32", also `AudioFormat.F32` or `AudioFormat.S16`). + sample_rate: Target sample rate (e.g., 48000, 16000, 8000) + channels: Target number of channels (1 for mono, 2 for stereo) + frame_size: how many samples per channel are produce in each output frame. + When set, the underlying resampler will buffer output the specified number of samples is accumulated, + and it will output frames of this exact size (except when ".resample(flush=True)"). + Default - `0` (each frame can be of a variable size). + """ + if isinstance(format, str): + AudioFormat.validate(format) + + self.format = AudioFormat(format) + self.sample_rate = sample_rate + self.channels = channels + self.frame_size = frame_size + # Determine PyAV format based on original format to preserve it + # f32 -> fltp (float32 planar), s16 -> s16p (int16 planar) + self._pyav_format = "fltp" if self.format == AudioFormat.F32 else "s16p" + self._pts = 0 + self._set_pyav_resampler() + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(format={self.format.name.lower()!r}, " + f"sample_rate={self.sample_rate}, channels={self.channels}, frame_size={self.frame_size})" + ) + + def _set_pyav_resampler(self): + # Create PyAV resampler with format matching the original + self._pyav_resampler = av.AudioResampler( + format=self._pyav_format, + layout="mono" if self.channels == 1 else "stereo", + rate=self.sample_rate, + frame_size=self.frame_size, + ) + + def _pyav_resample(self, frame: av.AudioFrame | None) -> list[av.AudioFrame]: + if frame is not None and not frame.samples: + # pyav resampler fails if audioframe has no samples + return [] + return self._pyav_resampler.resample(frame) + + def resample(self, pcm: PcmData, flush: bool = False) -> PcmData: + """ + Resample using PyAV (libav) for high-quality resampling and downmixing. + + Args: + pcm: Input PCM data to resample + flush: if True, get the remaining frames from underlying `av.AudioResampler` if there are any. + Default - `False`. + + Returns: + New PcmData object with resampled audio, potentially empty if the frame size is set and larger + than the input PCM. + """ + # Create AudioFrame from PcmData + frame = pcm.to_av_frame() + + # Convert each frame to PcmData using from_av_frame and concatenate them + # Start with an empty PcmData preserving the original format + result = PcmData( + sample_rate=self.sample_rate, + format=self.format, + channels=self.channels, + time_base=1 / self.sample_rate, + ) + + # Resample + # Keep the lock because resampler is stateful, and we want to keep PTS in order + resampled_frames = self._pyav_resample(frame) + if flush: + try: + resampled_frames.extend(self._pyav_resample(None)) + finally: + # Reset the resampler because it cannot be used after it's flushed, + self._set_pyav_resampler() + + for resampled_frame in resampled_frames: + self._pts += resampled_frame.samples + result = result.append(PcmData.from_av_frame(resampled_frame)) + + result.pts = self._pts - len(result.samples) + result.dts = result.pts + return result + + def flush(self) -> PcmData: + """ + Flush the underlying `av.AudioResampler` + + Returns: + New PcmData object with resampled audio, potentially empty. + """ + # Convert each frame to PcmData using from_av_frame and concatenate them + # Start with an empty PcmData preserving the original format + result = PcmData( + sample_rate=self.sample_rate, + format=self.format, + channels=self.channels, + ) + + try: + # Flush the resampler to get remaining buffered samples + resampled_frames = self._pyav_resample(None) + finally: + # Reset the resampler because it cannot be used after it's flushed, + self._set_pyav_resampler() + + # Convert frames to PcmData and update the PTS clock + for resampled_frame in resampled_frames: + self._pts += resampled_frame.samples + result = result.append(PcmData.from_av_frame(resampled_frame)) + result.pts = self._pts - len(result.samples) + return result + class Resampler: """ @@ -1867,7 +2029,7 @@ def _adjust_format( def __repr__(self) -> str: return ( - f"Resampler(format={self.format!r}, " + f"{self.__class__.__name__}(format={self.format!r}, " f"sample_rate={self.sample_rate}, channels={self.channels})" ) diff --git a/tests/rtc/test_pcm_data.py b/tests/rtc/test_pcm_data.py index bc5ab77f..79581ccb 100644 --- a/tests/rtc/test_pcm_data.py +++ b/tests/rtc/test_pcm_data.py @@ -1,9 +1,15 @@ +from fractions import Fraction + +import av import numpy as np import pytest -import av -from fractions import Fraction -from getstream.video.rtc.track_util import PcmData, AudioFormat, Resampler +from getstream.video.rtc.track_util import ( + AudioFormat, + PcmData, + Resampler, + PyAVResampler, +) def _i16_list_from_bytes(b: bytes): @@ -1968,3 +1974,353 @@ def test_g711_integration(tmp_path): # Check that decoded audio isn't all zeros assert np.any(pcm_decoded_mulaw.samples != 0) assert np.any(pcm_decoded_alaw.samples != 0) + + +class TestPyAVResampler: + def test_resample_upsample(self): + """Test basic resampling functionality.""" + # Create a resampler for 48kHz mono s16 + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Create 20ms of 16kHz audio (320 samples) + samples = np.random.randint(-1000, 1000, 320, dtype=np.int16) + pcm_16k = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + + # Resample to 48kHz + pcm_48k = resampler.resample(pcm_16k) + + assert pcm_48k.sample_rate == 48000 + assert pcm_48k.format == "s16" + assert pcm_48k.channels == 1 + + # 16kHz to 48kHz is 3x upsampling: 320 * 3 = 960 + # Stateful resampler keeps some frames in the buffer + assert len(pcm_48k.samples) < 960 + # We need to flush the resampler to get the remaining ones + pcm_flushed = resampler.flush() + assert len(pcm_flushed.samples) + len(pcm_48k.samples) == 960 + + # Calling resample with "flush=True" returns all samples at once + assert len(resampler.resample(pcm_16k, flush=True).samples) == 960 + + def test_resample_downsample(self): + """Test downsampling from 48kHz to 16kHz.""" + resampler = PyAVResampler(format="s16", sample_rate=16000, channels=1) + + # Create 48kHz audio (960 samples = 20ms) + samples = np.random.randint(-1000, 1000, 960, dtype=np.int16) + pcm_48k = PcmData(samples=samples, sample_rate=48000, format="s16", channels=1) + + # Downsample to 16kHz + pcm_16k = resampler.resample(pcm_48k) + + assert pcm_16k.sample_rate == 16000 + # 48kHz to 16kHz is 1/3x: 960 / 3 = 320 + # Stateful resampler keeps some frames in the buffer + assert len(pcm_16k.samples) < 320 + # We need to flush the resampler to get the remaining ones + pcm_flushed = resampler.flush() + assert len(pcm_flushed.samples) + len(pcm_16k.samples) == 320 + + # Calling resample with "flush=True" returns all samples at once + assert len(resampler.resample(pcm_48k, flush=True).samples) == 320 + + def test_resample_no_change(self): + """Test that resampler returns same data when no resampling needed.""" + resampler = PyAVResampler(format="s16", sample_rate=16000, channels=1) + + samples = np.array([1, 2, 3, 4], dtype=np.int16) + pcm = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + + result = resampler.resample(pcm) + + assert result.sample_rate == 16000 + assert result.format == "s16" + assert result.channels == 1 + np.testing.assert_array_equal(result.samples, samples) + + def test_resample_different_sample_rates_fails(self): + # Create a resampler for 48kHz mono s16 + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + samples = np.random.randint(-1000, 1000, 320, dtype=np.int16) + # Create 20ms of 16kHz audio (320 samples) + pcm_16k = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + # Create 40ms of 16kHz audio (320 samples) + pcm_8k = PcmData(samples=samples, sample_rate=8000, format="s16", channels=1) + + resampler.resample(pcm_16k) + # Feeding a different sample rate fails + with pytest.raises( + ValueError, match="Frame does not match AudioResampler setup" + ): + resampler.resample(pcm_8k) + + def test_resample_different_formats_fails(self): + # Create a resampler for 48kHz mono s16 + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Create 20ms of 16kHz audio (320 samples) + pcm_16k_s16 = PcmData( + samples=np.linspace(-1000, 1000, 320, dtype=np.int16), + sample_rate=16000, + format="s16", + channels=1, + ) + pcm_16k_f32 = PcmData( + samples=np.linspace(-1000, 1000, 320, dtype=np.float32), + sample_rate=16000, + format="f32", + channels=1, + ) + + resampler.resample(pcm_16k_s16) + # Feeding a different format fails + with pytest.raises( + ValueError, match="Frame does not match AudioResampler setup" + ): + resampler.resample(pcm_16k_f32) + + def test_resample_upsample_frame_size_set(self): + """Test that resampler accumulates samples before returning if frame size is set.""" + # Create a resampler for 8kHz mono s16 + # with a frame size 160 (20ms) + resampler = PyAVResampler( + format="s16", + sample_rate=8000, + channels=1, + frame_size=160, + ) + + # Create 100 samples of 16kHz audio + samples = np.random.randint(-1000, 1000, 100, dtype=np.int16) + pcm_16k = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + + # Resample to 8kHz + pcm_8k = resampler.resample(pcm_16k) + + assert pcm_8k.sample_rate == 8000 + assert pcm_8k.format == "s16" + assert pcm_8k.channels == 1 + + # 100 samples are not enough for a frame of 160 8khz + assert len(pcm_8k.samples) == 0 + # Feed it 3 more times to get some output + resampler.resample(pcm_16k) + resampler.resample(pcm_16k) + pcm_8k = resampler.resample(pcm_16k) + assert len(pcm_8k.samples) == 160 + + # Flush resampler to get the last 40 samples + assert len(resampler.flush().samples) == 40 + + def test_resample_mono_to_stereo(self): + """Test resampling with channel conversion from mono to stereo.""" + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=2) + + # Create 64 samples of mono 16kHz audio + mono_samples = np.linspace(-100, 100, num=64, dtype=np.int16) + pcm_mono = PcmData( + samples=mono_samples, sample_rate=16000, format="s16", channels=1 + ) + + # Resample to stereo 48kHz + pcm_stereo = resampler.resample(pcm_mono, flush=True) + + assert pcm_stereo.sample_rate == 48000 + assert pcm_stereo.channels == 2 + assert pcm_stereo.samples.shape[0] == 2 # 2 channels + # Both channels should have the same data (duplicated from mono) + np.testing.assert_array_equal(pcm_stereo.samples[0], pcm_stereo.samples[1]) + + def test_resample_stereo_to_mono(self): + """Test resampling with channel conversion from stereo to mono.""" + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Create stereo 16kHz audio + left_channel = np.linspace(-100, 100, num=64, dtype=np.int16) + right_channel = np.linspace(-100, 100, num=64, dtype=np.int16) + stereo_samples = np.vstack([left_channel, right_channel]) + pcm_stereo = PcmData( + samples=stereo_samples, sample_rate=16000, format="s16", channels=2 + ) + + # Resample to mono 48kHz + pcm_mono = resampler.resample(pcm_stereo) + + assert pcm_mono.sample_rate == 48000 + assert pcm_mono.channels == 1 + assert pcm_mono.samples.ndim == 1 # 1D array for mono + + def test_resample_format_conversion_to_f32(self): + """Test format conversion from s16 to f32.""" + resampler = PyAVResampler(format="f32", sample_rate=16000, channels=1) + + # Create s16 audio + samples = np.array([0, 16384, -16384, 32767, -32768], dtype=np.int16) + pcm_s16 = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + + # Convert to f32 + pcm_f32 = resampler.resample(pcm_s16) + + assert pcm_f32.format == "f32" + assert pcm_f32.samples.dtype == np.float32 + # Check value ranges are properly scaled to [-1, 1] + assert -1.0 <= pcm_f32.samples.min() <= 1.0 + assert -1.0 <= pcm_f32.samples.max() <= 1.0 + + def test_resample_format_conversion_to_s16(self): + """Test format conversion from f32 to s16.""" + resampler = PyAVResampler(format="s16", sample_rate=16000, channels=1) + + # Create f32 audio + samples = np.array([0.0, 0.5, -0.5, 1.0, -1.0], dtype=np.float32) + pcm_f32 = PcmData(samples=samples, sample_rate=16000, format="f32", channels=1) + + # Convert to s16 + pcm_s16 = resampler.resample(pcm_f32) + + assert pcm_s16.format == "s16" + assert pcm_s16.samples.dtype == np.int16 + # Check values are in int16 range + assert -32768 <= pcm_s16.samples.min() <= 32767 + assert -32768 <= pcm_s16.samples.max() <= 32767 + + def test_resample_20ms_chunks(self): + """Test resampling of consecutive 20ms chunks (simulating real-time streaming).""" + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Simulate 5 consecutive 20ms chunks at 16kHz + chunks = [] + for i in range(5): + # 20ms at 16kHz = 320 samples + samples = ( + np.sin(2 * np.pi * 440 * (np.arange(320) + i * 320) / 16000) * 10000 + ) + samples = samples.astype(np.int16) + pcm = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + chunks.append(pcm) + + # Resample each chunk independently (simulating real-time processing) + resampled_chunks = [] + for chunk in chunks: + resampled = resampler.resample(chunk, flush=True) + resampled_chunks.append(resampled) + # Each 20ms chunk at 48kHz should be 960 samples + assert len(resampled.samples) == 960 + assert resampled.sample_rate == 48000 + + # Verify no state is maintained between chunks by checking each is processed identically + # Two identical input chunks should produce identical outputs + identical_chunk = PcmData( + samples=chunks[0].samples, sample_rate=16000, format="s16", channels=1 + ) + resampled1 = resampler.resample(chunks[0], flush=True) + resampled2 = resampler.resample(identical_chunk, flush=True) + np.testing.assert_array_equal(resampled1.samples, resampled2.samples) + + def test_resample_time_handling(self): + """ + Test that PTS/DTS timestamps increase monotonically after resampling. + """ + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + samples = np.linspace(-100, 100, num=320, dtype=np.int16) + pcm = PcmData( + samples=samples, + sample_rate=16000, + format="s16", + channels=1, + pts=1234, + dts=1230, + time_base=0.001, + ) + + resampled = resampler.resample(pcm, flush=True) + + # Stateful resampler keeps track of the time on its own. + # It resets pts when getting the first sample, and it tracks the new pts value + # based on the number of output samples. + # The values from PcmData objects are ignored. + assert resampled.pts == 0 + assert resampled.dts == 0 + assert resampled.time_base == 1 / resampler.sample_rate + + resampled = resampler.resample(pcm, flush=True) + assert resampled.pts == 960 + assert resampled.dts == 960 + + def test_repr(self): + """Test string representation of Resampler.""" + resampler = PyAVResampler( + format="f32", sample_rate=44100, channels=2, frame_size=1024 + ) + repr_str = repr(resampler) + + assert "format='f32'" in repr_str + assert "sample_rate=44100" in repr_str + assert "channels=2" in repr_str + assert "frame_size=1024" in repr_str + + def test_resample_empty_audio(self): + """Test edge cases like empty audio, single sample, etc.""" + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Empty audio + empty_pcm = PcmData( + samples=np.array([], dtype=np.int16), + sample_rate=16000, + format="s16", + channels=1, + ) + resampled_empty = resampler.resample(empty_pcm) + assert len(resampled_empty.samples) == 0 + + def test_flush_empty(self): + """Test flushing the resampler before processing any samples.""" + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Empty audio + resampled_empty = resampler.flush() + assert len(resampled_empty.samples) == 0 + + def test_resample_consistency_across_chunks(self): + """Test that splitting audio and processing in chunks gives consistent results.""" + resampler = PyAVResampler(format="s16", sample_rate=48000, channels=1) + + # Create a longer audio signal + total_samples = 1600 # 100ms at 16kHz + samples = np.sin(2 * np.pi * 440 * np.arange(total_samples) / 16000) * 10000 + samples = samples.astype(np.int16) + + # Process as one chunk and flush immediately + pcm_full = PcmData(samples=samples, sample_rate=16000, format="s16", channels=1) + resampled_full = resampler.resample(pcm_full, flush=True) + + # Process as multiple 20ms chunks + chunk_size = 320 # 20ms at 16kHz + resampled_chunks = [] + for i in range(0, total_samples, chunk_size): + chunk_samples = samples[i : i + chunk_size] + pcm_chunk = PcmData( + samples=chunk_samples, sample_rate=16000, format="s16", channels=1 + ) + resampled_chunk = resampler.resample(pcm_chunk) + resampled_chunks.append(resampled_chunk.samples) + + # Flush the resampler after processing all the chunks + resampled_chunks.append(resampler.flush().samples) + + # Concatenate chunks + resampled_concatenated = np.concatenate(resampled_chunks) + + # The results should be similar, though with some differences due to + # independent chunk processing. The stateless resampler uses endpoint + # mapping for each chunk, which prevents out-of-bounds access but creates + # phase differences compared to processing as one continuous signal. + assert len(resampled_full.samples) == len(resampled_concatenated) + # Check that the difference is reasonable for independent chunk processing + diff = np.abs(resampled_full.samples - resampled_concatenated) + assert ( + np.mean(diff) < 250 + ) # Allow for phase differences in stateless processing