|
import gradio as gr

import json
import logging
import importlib

import torch
import torchaudio
import torchaudio.transforms as T
from transformers import Wav2Vec2FeatureExtractor

# The MERT modeling code lives in a directory whose name ("MERT-v0-public")
# is not a valid Python identifier, so it is loaded via importlib rather than
# a plain import statement.
modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")

from Prediction_Head.MTGGenre_head import MLPProberBase
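# Expected layout next to this script (an assumption inferred from the paths
# used below): ./MERT-v0-public/ holds the backbone weights plus its modeling
# code, and ./Prediction_Head/ holds the probe checkpoint and its id-to-class
# mapping.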
logger = logging.getLogger("mert-genre-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    gr.components.Audio(source="microphone", type="filepath"),
]
live_inputs = [
    gr.components.Audio(source="microphone", streaming=True, type="filepath"),
]
title = "Predict the top 5 possible genres and tags of Music" |
|
description = "An example of using map/MERT-95M-public model as backbone to conduct music genre/tagging predcition." |
|
article = "" |
|
audio_examples = [ |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")

# Index of the MERT hidden-state layer whose time-averaged features feed the
# genre probe.
MERT_LAYER_IDX = 7

MTGGenre_classifier = MLPProberBase()
# map_location="cpu" keeps the checkpoint loadable on CPU-only machines even
# if it was saved during a GPU run.
MTGGenre_classifier.load_state_dict(
    torch.load("Prediction_Head/best_MTGGenre.ckpt", map_location="cpu")["state_dict"]
)

with open("Prediction_Head/MTGGenre_id2class.json", "r") as f:
    id2cls = json.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
MTGGenre_classifier.to(device)
# Both networks are inference-only here, so switch off dropout and the like.
model.eval()
MTGGenre_classifier.eval()
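# Shape sanity check (a sketch, assuming the 12-layer, 768-dim MERT-v0
# configuration): stacking the hidden states of a 5-second clip at the
# processor's sampling rate gives roughly a [13, ~249, 768] tensor, so
# MERT_LAYER_IDX must lie in [0, 12].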
|
|
|
def predict_genres(filepath):
    """Run MERT on an audio file and return the top-5 MTG genre tags."""
    waveform, sample_rate = torchaudio.load(filepath)

    # Resample to the rate the feature extractor expects.
    resample_rate = processor.sampling_rate
    if resample_rate != sample_rate:
        logger.info(f"resampling from {sample_rate} Hz to {resample_rate} Hz")
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)

    waveform = waveform.view(-1,)  # flatten to 1-D (channels are concatenated)
    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)

    # [num_layers + 1, time, hidden] once the batch dimension is squeezed away.
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()

    # Time-average the chosen layer, then score every genre/tag class.
    logits = MTGGenre_classifier(torch.mean(all_layer_hidden_states[MERT_LAYER_IDX], dim=0))
    sorted_idx = torch.argsort(logits, dim=-1, descending=True)

    # Map the five highest-scoring class ids to names, dropping the
    # "genre---" prefix used in the MTG label set.
    output_texts = "\n".join(
        id2cls[str(idx.item())].replace("genre---", "") for idx in sorted_idx[:5]
    )
    return f"device: {device}\n" + output_texts


def convert_audio(file_upload, microphone):
    # Prefer the microphone recording when both inputs are given.
    filepath = microphone if microphone is not None else file_upload
    return predict_genres(filepath)


def live_convert_audio(microphone):
    # Streaming can deliver an empty chunk before any audio has arrived.
    if microphone is None:
        return ""
    return predict_genres(microphone)
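# A quick local smoke test for the shared helper, assuming a clip exists at
# the hypothetical path ./example.wav:
#
#   print(predict_genres("./example.wav"))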
|
|
|
|
|
audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=[gr.components.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)
live_audio_chunked = gr.Interface(
    fn=live_convert_audio,
    inputs=live_inputs,
    outputs=[gr.components.Textbox()],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    live=True,
)
demo = gr.Blocks()
with demo:
    gr.TabbedInterface(
        [
            audio_chunked,
            live_audio_chunked,
        ],
        [
            "Audio File or Recording",
            "Live Streaming Music",
        ],
    )
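# queue(concurrency_count=1) handles requests one at a time through the single
# shared model; max_size=5 caps how many requests may wait before new ones are
# rejected.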
|
demo.queue(concurrency_count=1, max_size=5) |
|
demo.launch(show_api=False) |