# tinyCLAP / app.py
"""This recipe to train CLAP.
It supports distillation using tinyCLAP (https://arxiv.org/abs/2311.14517).
Authors
* Francesco Paissan 2024
"""
import sys
import gradio as gr
import speechbrain as sb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.metric_stats import MetricStats
torch.backends.cudnn.enabled = False
eps = 1e-10
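# cuDNN is disabled here as in the training recipe; `eps` is a small constant
# kept from that recipe (it is not referenced elsewhere in this inference app).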
class CLAPBrain(sb.Brain):
    """Brain wrapping the CLAP teacher and the tinyCLAP student."""

    def preprocess(self, wavs):
        """Pre-processes wavs: STFT spectrogram followed by log-mel features."""
        x = self.hparams.spectrogram_extractor(wavs)
        x = self.hparams.logmel_extractor(x)
        return x
    def prepare_txt_features(self, text):
        """Prepares text features to input into the CLAP text encoder."""
        txt_inp = self.hparams.txt_tokenizer(
            text,
            max_length=self.hparams.text_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        return txt_inp
    def compute_sim(self, audio_embed, caption_embed):
        """Computes the CLAP similarity as a dot product of the embeddings."""
        similarity = audio_embed @ caption_embed.t()
        return similarity
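    # Note: the dot product above equals cosine similarity only when both
    # embeddings are L2-normalized, as the student embedding is in
    # compute_forward below; the teacher outputs are assumed to be normalized
    # inside the CLAP model itself.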
    def compute_forward(self, batch, stage):
        """Computes text and audio embeddings in the shared CLAP space."""
        if len(batch) == 2:
            wavs, caption = batch
        else:
            wavs, caption, _, _ = batch

        wavs = wavs.to(self.device).squeeze(1)
        x_sb = self.preprocess(wavs)

        text_inp = self.prepare_txt_features(caption)

        # Teacher CLAP embeddings (text and audio projected to the shared space).
        txt_shared, aud_shared = self.hparams.clap(
            x_sb,
            text_inp.input_ids.data,
            text_inp.token_type_ids.data,
            text_inp.attention_mask.data,
        )

        # Initialize to None so the return below never hits an undefined name
        # when a teacher is registered among the trained modules.
        aud_shared_student = None
        if not hasattr(self.modules, "clap"):
            # Student (distilled) audio embedding, L2-normalized.
            aud_shared_student, _, _ = self.modules.clap_student(x_sb)
            aud_shared_student = aud_shared_student / aud_shared_student.norm(
                dim=1, keepdim=True
            )

        return txt_shared, aud_shared, aud_shared_student
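# In tinyCLAP (arXiv:2311.14517), the student audio encoder is distilled to
# match the teacher's audio embedding in the shared text-audio space, so at
# inference the similarity can be computed from the student embedding alone.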
def audio_preprocess(x, sample_rate):
    """Loads an audio file, resamples it, and downmixes it to mono by summing channels."""
    tmp, sr = torchaudio.load(x)
    resample = T.Resample(sr, sample_rate)
    tmp = resample(tmp)
    tmp = tmp.sum(0, keepdim=True)
    return tmp
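# Example (the 44100 rate is illustrative; the actual value comes from hparams):
#   mono = audio_preprocess("./siren.wav", 44100)  # -> tensor of shape (1, num_samples)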
def inference_wrapper(clap_brain):
    """Builds the Gradio inference function around a trained CLAPBrain."""

    # no_grad must decorate the inner function: decorating the wrapper would
    # only disable gradients while `f` is being defined, not when it runs.
    @torch.no_grad()
    def f(wav_path, prompt):
        clap_brain.modules.eval()
        tmp = audio_preprocess(wav_path, clap_brain.hparams.sample_rate)
        ret = clap_brain.compute_forward([tmp, prompt], stage=sb.Stage.TEST)
        # ret = (text embedding, teacher audio embedding, student audio embedding)
        sim = clap_brain.compute_sim(ret[2], ret[0])
        return f"tinyCLAP similarity is: {round(sim.item(), 2)}"

    return f
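# The returned callable matches the Gradio interface below, e.g.:
#   f("./siren.wav", "this is the sound of sirens wailing")
# returns a string of the form "tinyCLAP similarity is: <score>".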
if __name__ == "__main__":
# CLI:
# hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
hparams_file = "hparams/inference.yaml"
# Load hyperparameters file with command-line overrides
with open(hparams_file) as fin:
hparams = load_hyperpyyaml(fin, {})
# Tensorboard logging
if hparams["use_tensorboard"]:
from speechbrain.utils.train_logger import TensorboardLogger
hparams["tensorboard_train_logger"] = TensorboardLogger(
hparams["tensorboard_logs_folder"]
)
hparams["clap"].to(hparams["device"])
hparams["clap"].requires_grad_(False)
hparams["clap"].eval()
if hparams["zs_eval"]:
hparams["class_list"] = datasets["train"].dataset.classes
if hparams["audioenc_name_student"] is not None:
if hparams["projection_only"]:
print("Freezing Base AudioEncoder. Updating only the projection layers.")
hparams["student_model"].base.requires_grad_(False)
hparams["spectrogram_extractor"].to(hparams["device"])
hparams["logmel_extractor"].to(hparams["device"])
    clap_brain = CLAPBrain(
        modules=hparams["modules"],
        hparams=hparams,
    )

    if hparams["pretrained_CLAP"] is not None:
        print("Loading CLAP model...")
        run_on_main(hparams["load_CLAP"].collect_files)
        hparams["load_CLAP"].load_collected()
    inference_api = inference_wrapper(clap_brain)

    examples_list = [
        ["./tunztunz_music.wav", "this is the sound of house music"],
        ["./siren.wav", "this is the sound of sirens wailing"],
        [
            "./whistling_and_chirping.wav",
            "someone is whistling while birds are chirping",
        ],
    ]

    demo = gr.Interface(
        fn=inference_api,
        inputs=[gr.Audio(type="filepath"), gr.Textbox()],
        outputs=["text"],
        examples=examples_list,
    )
    demo.launch()
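    # demo.launch() uses Gradio defaults, which is what a Hugging Face Space
    # expects; when running locally, launch(share=True) can be passed instead
    # to expose a temporary public URL.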