import os
import warnings
from functools import cache
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torchaudio
import yaml
from torch import nn

from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper
from models.training_environment import TrainingEnvironment

CONFIG_FILE = Path("models/config/train_local.yaml")
MODEL_CLS = AST
EXTRACTOR = ASTExtractorWrapper


def _default_device() -> str:
    """Return the best available torch device name: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class DancePredictor:
    """Suggest dances for a clip of audio using a trained audio classifier.

    Instances are callable: ``predictor(waveform, sample_rate)`` returns a
    mapping of dance label -> probability for every label whose softmax
    probability exceeds ``threshold``.
    """

    def __init__(
        self,
        weight_path: str,
        labels: list[str],
        expected_duration=6,
        threshold=0.5,
        resample_frequency=16000,
        device="cpu",
    ):
        """
        Args:
            weight_path: Checkpoint file whose ``"state_dict"`` entry holds the
                model weights.
            labels: Dance labels, ordered to match the model's output positions
                (output ``i`` is reported as ``labels[i]``).
            expected_duration: Seconds of (resampled) audio kept from the start
                of each clip; the rest is discarded.
            threshold: Minimum softmax probability for a dance to be reported.
            resample_frequency: Sample rate (Hz) audio is resampled to before
                feature extraction.
            device: Torch device the model and features are placed on.
        """
        super().__init__()
        self.expected_duration = expected_duration
        self.threshold = threshold
        self.resample_frequency = resample_frequency
        self.labels = np.array(labels)
        self.device = device
        self.model = self.get_model(weight_path)
        self.extractor = EXTRACTOR()

    def get_model(self, weight_path: str) -> nn.Module:
        """Build the model, load the checkpoint weights and prepare it for inference.

        Returns:
            The model on ``self.device``, switched to eval mode.
        """
        checkpoint = torch.load(weight_path, map_location=self.device)
        # Checkpoint keys carry a "model." prefix, presumably added by the
        # training wrapper (TrainingEnvironment) -- confirm.
        # NOTE(review): str.replace strips *every* "model." occurrence, not only
        # the leading prefix; combined with strict=False a mismatch would be
        # silent, which is why missing keys are reported below.
        weights = {
            key.replace("model.", ""): value
            for key, value in checkpoint["state_dict"].items()
        }
        model = MODEL_CLS(self.labels).to(self.device)
        incompatible = model.load_state_dict(weights, strict=False)
        if incompatible.missing_keys:
            warnings.warn(
                f"{len(incompatible.missing_keys)} model parameter(s) were not "
                f"found in {weight_path}; predictions may be unreliable. "
                f"First missing keys: {incompatible.missing_keys[:5]}"
            )
        # Inference only: make layers such as dropout behave deterministically.
        model.eval()
        return model

    @classmethod
    def from_config(cls, config_path: str) -> "DancePredictor":
        """Create a predictor from a training YAML config.

        Reads ``checkpoint`` (weight path) and ``dance_ids`` (labels, sorted
        alphabetically) from the config. The remaining settings use fixed
        inference defaults, and the device is the best one available.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        return cls(
            weight_path=config["checkpoint"],
            labels=sorted(config["dance_ids"]),
            expected_duration=6,
            threshold=0.5,
            resample_frequency=16000,
            device=_default_device(),
        )

    @torch.no_grad()
    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
        """Predict which dances fit a clip of audio.

        Args:
            waveform: Audio samples, either mono ``(samples,)`` or
                ``(samples, channels)``.
            sample_rate: Sample rate of ``waveform`` in Hz.

        Returns:
            ``{label: probability}`` for each label whose probability exceeds
            ``self.threshold``; empty when none does. Because softmax
            probabilities sum to 1, a threshold >= 0.5 yields at most one label.
        """
        if waveform.ndim == 1:
            # Duplicate mono audio into two channels -> (samples, 2).
            waveform = np.stack([waveform, waveform]).T
        # torchaudio works channels-first: (channels, samples).
        audio = torch.from_numpy(waveform.T)
        # WAV encode/decode round trip; presumably used to normalise the
        # sample format (e.g. int16 -> float) -- confirm.
        audio = torchaudio.functional.apply_codec(
            audio, sample_rate, "wav", channels_first=True
        )
        audio = torchaudio.functional.resample(
            audio, sample_rate, self.resample_frequency
        )
        audio = audio[:, : self.resample_frequency * self.expected_duration]
        # TODO: pad clips shorter than expected_duration.
        features = self.extractor(audio).unsqueeze(0).to(self.device)
        logits = self.model(features)
        probabilities = (
            nn.functional.softmax(logits.squeeze(0), dim=0).cpu().numpy()
        )
        mask = probabilities > self.threshold
        return {
            dance: float(prob)
            for dance, prob in zip(self.labels[mask], probabilities[mask])
        }


@cache
def get_model(config_path: str) -> DancePredictor:
    """Build the predictor for ``config_path`` once and reuse it on later calls."""
    return DancePredictor.from_config(config_path)


def predict(audio: "tuple[int, np.ndarray] | None") -> "dict[str, float] | str":
    """Gradio callback: map ``(sample_rate, waveform)`` to dance probabilities.

    Returns the label -> probability mapping, or ``"Dance Not Found"`` when no
    dance passes the threshold.

    Raises:
        gr.Error: if no audio was provided.
    """
    if audio is None:
        raise gr.Error("Please provide some audio.")
    sample_rate, waveform = audio
    results = get_model(CONFIG_FILE)(waveform, sample_rate)
    return results or "Dance Not Found"


def demo() -> gr.Blocks:
    """Build the Gradio app with "Upload Song" and "Record Song" tabs."""
    title = "Dance Classifier"
    description = "What should I dance to this song? Pass some audio to the Dance Classifier to find out!"
    song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
    # Skip hidden files (names starting with "."); sort for a stable example order.
    example_audio = sorted(
        str(song)
        for song in song_samples.iterdir()
        if not song.name.startswith(".")
    )
    all_dances = get_model(CONFIG_FILE).labels

    recording_interface = gr.Interface(
        fn=predict,
        description="Record at least **6 seconds** of the song.",
        inputs=gr.Audio(source="microphone", label="Song Recording"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
    )
    uploading_interface = gr.Interface(
        fn=predict,
        inputs=gr.Audio(label="Song Audio File"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
    )

    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        gr.TabbedInterface(
            [uploading_interface, recording_interface], ["Upload Song", "Record Song"]
        )
        with gr.Accordion("See all dances", open=False):
            gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))

    return app


if __name__ == "__main__":
    demo().launch()