"""This recipe to train CLAP. It supports distillation using tinyCLAP (https://arxiv.org/abs/2311.14517). Authors * Francesco Paissan 2024 """ import sys import gradio as gr import speechbrain as sb import torch import torch.nn as nn import torch.nn.functional as F import torchaudio import torchaudio.transforms as T from hyperpyyaml import load_hyperpyyaml from speechbrain.utils.distributed import run_on_main from speechbrain.utils.metric_stats import MetricStats torch.backends.cudnn.enabled = False eps = 1e-10 class CLAPBrain(sb.Brain): def preprocess(self, wavs): """Pre-process wavs.""" x = self.hparams.spectrogram_extractor(wavs) x = self.hparams.logmel_extractor(x) return x def prepare_txt_features(self, text): """Prepares text features to input in CLAP text encoder.""" txt_inp = self.hparams.txt_tokenizer( text, max_length=self.hparams.text_max_length, padding="max_length", truncation=True, return_tensors="pt", ).to(self.device) return txt_inp def compute_sim(self, audio_embed, caption_embed): """Computes CLAP similarity metric.""" similarity = audio_embed @ caption_embed.t() return similarity def compute_forward(self, batch, stage): if len(batch) == 2: wavs, caption = batch else: wavs, caption, _, _ = batch wavs = wavs.to(self.device).squeeze(1) x_sb = self.preprocess(wavs) text_inp = self.prepare_txt_features(caption) txt_shared, aud_shared = self.hparams.clap( x_sb, text_inp.input_ids.data, text_inp.token_type_ids.data, text_inp.attention_mask.data, ) if not hasattr(self.modules, "clap"): aud_shared_student, _, _ = self.modules.clap_student(x_sb) aud_shared_student = aud_shared_student / aud_shared_student.norm( dim=1, keepdim=True ) return txt_shared, aud_shared, aud_shared_student def audio_preprocess(x, sample_rate): tmp, sr = torchaudio.load(x) resample = T.Resample(sr, sample_rate) tmp = resample(tmp) tmp = tmp.sum(0, keepdims=True) return tmp @torch.no_grad() def inference_wrapper(clap_brain): def f(wav_path, prompt): clap_brain.modules.eval() tmp = audio_preprocess(wav_path, clap_brain.hparams.sample_rate) ret = clap_brain.compute_forward([tmp, prompt], stage=sb.Stage.TEST) sim = clap_brain.compute_sim(ret[2], ret[0]) return f"tinyCLAP similarity is: {round(sim.item(), 2)}" return f if __name__ == "__main__": # CLI: # hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) hparams_file = "hparams/inference.yaml" # Load hyperparameters file with command-line overrides with open(hparams_file) as fin: hparams = load_hyperpyyaml(fin, {}) # Tensorboard logging if hparams["use_tensorboard"]: from speechbrain.utils.train_logger import TensorboardLogger hparams["tensorboard_train_logger"] = TensorboardLogger( hparams["tensorboard_logs_folder"] ) hparams["clap"].to(hparams["device"]) hparams["clap"].requires_grad_(False) hparams["clap"].eval() if hparams["zs_eval"]: hparams["class_list"] = datasets["train"].dataset.classes if hparams["audioenc_name_student"] is not None: if hparams["projection_only"]: print("Freezing Base AudioEncoder. 
Updating only the projection layers.") hparams["student_model"].base.requires_grad_(False) hparams["spectrogram_extractor"].to(hparams["device"]) hparams["logmel_extractor"].to(hparams["device"]) clap_brain = CLAPBrain( modules=hparams["modules"], hparams=hparams, ) if hparams["pretrained_CLAP"] is not None: print("Loading CLAP model...") run_on_main(hparams["load_CLAP"].collect_files) hparams["load_CLAP"].load_collected() inference_api = inference_wrapper(clap_brain) examples_list = [ ["./tunztunz_music.wav", "this is the sound of house music"], ["./siren.wav", "this is the sound of sirens wailing"], [ "./whistling_and_chirping.wav", "someone is whistling while birds are chirping", ], ] demo = gr.Interface( fn=inference_api, inputs=[gr.Audio(type="filepath"), gr.Textbox()], outputs=["text"], examples=examples_list, ) demo.launch()
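
# Usage sketch: one way to launch this demo, assuming the script is saved
# as `app.py` (a hypothetical name) alongside `hparams/inference.yaml` and
# the example wav files listed above:
#
#   pip install speechbrain gradio torchaudio hyperpyyaml
#   python app.py
#
# Gradio then prints a local URL serving the interface; uploading an audio
# file and entering a text prompt returns the tinyCLAP similarity score.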