diff --git a/README.md b/README.md
index d640cd5..0b84979 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 This repository provides code for the Waveformer architecture proposed in the paper. Waveformer is a low-latency target sound extraction model implementing streaming inference -- the model processes a ~10 ms input audio chunk at each time step, while only looking at past chunks and no future chunks. On a Core i5 CPU using a single thread, real-time factors (RTFs) of different model configurations range from 0.66 to 0.94, with an end-to-end latency of less than 20 ms.
 
 [![Gradio demo](https://img.shields.io/badge/arxiv-abs-green)](https://arxiv.org/abs/2211.02250) [![Gradio demo](https://img.shields.io/badge/arxiv-pdf-green)](https://arxiv.org/pdf/2211.02250) [![Gradio demo](https://img.shields.io/badge/Gradio-app-blue)](https://huggingface.co/spaces/uwx/waveformer)
+[![App Gallery](https://bit.ly/3xTcccO)](https://01ghh2pnbdet9ex9sdqqsnpxwh.litng-ai-03.litng.ai/view)
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..ab0edb0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,80 @@
+import json
+import os
+
+import gradio as gr
+import lightning as L
+import torch
+import torchaudio
+import wget
+from lightning.app.components.serve import ServeGradio
+
+from Waveformer import TARGETS
+from Waveformer import Waveformer as WaveformerModel
+
+
+class ModelDemo(ServeGradio):
+    inputs = [
+        gr.Audio(label="Input audio"),
+        gr.CheckboxGroup(choices=TARGETS, label="Extract target sound"),
+    ]
+    outputs = gr.Audio(label="Output audio")
+    examples = [["data/Sample.wav"]]
+    enable_queue: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(cloud_compute=L.CloudCompute("cpu-medium"), **kwargs)
+        self._device = None
+
+    def build_model(self):
+        if not os.path.exists("default_config.json"):
+            config_url = (
+                "https://targetsound.cs.washington.edu/files/default_config.json"
+            )
+            print("Downloading model configuration from %s:" % config_url)
+            wget.download(config_url)
+
+        if not os.path.exists("default_ckpt.pt"):
+            ckpt_url = "https://targetsound.cs.washington.edu/files/default_ckpt.pt"
+            print("\nDownloading the checkpoint from %s:" % ckpt_url)
+            wget.download(ckpt_url)
+
+        # Instantiate the model and load the pretrained checkpoint
+        with open("default_config.json") as f:
+            params = json.load(f)
+        model = WaveformerModel(**params["model_params"])
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
+        self._device = device
+        print(f"loading model on {device}")
+        model.load_state_dict(
+            torch.load("default_ckpt.pt", map_location=self._device)["model_state_dict"]
+        )
+        return model.to(self._device).eval()
+
+    @torch.inference_mode()
+    def predict(self, audio, label_choices):
+        # Read input audio; Gradio provides a (sample_rate, int16 numpy array) tuple
+        fs, mixture = audio
+        if fs != 44100:
+            mixture = torchaudio.functional.resample(
+                torch.as_tensor(mixture, dtype=torch.float32), orig_freq=fs, new_freq=44100
+            ).numpy()
+
+        mixture = torch.from_numpy(mixture).unsqueeze(0).unsqueeze(0).to(
+            torch.float
+        ) / (2.0**15)  # int16 -> float32 in [-1, 1]
+
+        # Construct a one-hot query vector over the supported target classes
+        query = torch.zeros(1, len(TARGETS)).to(self._device)
+        for t in label_choices:
+            query[0, TARGETS.index(t)] = 1.0
+
+        with torch.inference_mode():
+            output = (2.0**15) * self.model(mixture.to(self._device), query)  # back to int16 scale
+
+        return 44100, output.squeeze(0).squeeze(0).to(torch.short).cpu().numpy()  # 44.1 kHz
+
+
+app = L.LightningApp(ModelDemo())
diff --git a/requirements.txt b/requirements.txt
index a472039..81d400c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,3 +16,4 @@ seaborn
 ipykernel
 scaper
 wget
+gradio
\ No newline at end of file
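
The checkpoint served by `app.py` can also be used outside the Gradio/Lightning app. The sketch below mirrors `build_model` and `predict` above: it loads `default_config.json` and `default_ckpt.pt` (downloaded as in `build_model`), builds a one-hot query over `TARGETS`, and runs the model on a mono 44.1 kHz clip. It is a hedged sketch, not part of this diff: `input.wav`, `output.wav`, and the choice of `TARGETS[0]` are placeholders, and it assumes the repository's `Waveformer.py` is importable from the working directory.

```python
# Offline inference sketch (assumes Waveformer.py from this repo and the
# default_config.json / default_ckpt.pt files fetched in build_model above;
# "input.wav", "output.wav", and TARGETS[0] are illustrative placeholders).
import json

import torch
import torchaudio

from Waveformer import TARGETS, Waveformer

# Build the model from the released configuration and checkpoint.
with open("default_config.json") as f:
    params = json.load(f)
model = Waveformer(**params["model_params"])
state = torch.load("default_ckpt.pt", map_location="cpu")["model_state_dict"]
model.load_state_dict(state)
model.eval()

# Load a clip, downmix to mono, and resample to the model's 44.1 kHz rate.
mixture, fs = torchaudio.load("input.wav")  # (channels, samples), float in [-1, 1]
mixture = mixture.mean(dim=0, keepdim=True)
if fs != 44100:
    mixture = torchaudio.functional.resample(mixture, orig_freq=fs, new_freq=44100)

# One-hot query selecting which target class to extract.
target = TARGETS[0]  # placeholder: pick any label from TARGETS
query = torch.zeros(1, len(TARGETS))
query[0, TARGETS.index(target)] = 1.0

with torch.inference_mode():
    output = model(mixture.unsqueeze(0), query)  # (1, 1, T) in -> (1, 1, T) out

torchaudio.save("output.wav", output.squeeze(0), 44100)
```

The demo app itself is presumably launched with the Lightning Apps CLI (e.g. `lightning run app app.py`), which is why `app.py` wraps the Gradio interface in `ServeGradio` and `LightningApp` rather than calling `gradio` directly.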