import gradio as gr

import importlib
import logging

import torch
import torchaudio
import torchaudio.transforms as T
from transformers import Wav2Vec2FeatureExtractor

# The checkpoint directory name contains hyphens, so its modeling code cannot
# be imported with a plain `import` statement; importlib takes the module
# path as a string instead.
modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")

logger = logging.getLogger("mert-tagging-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)

inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    # A microphone recording, when present, overrides the uploaded file
    # (see convert_audio below); a missing recording arrives as None.
    gr.components.Audio(source="microphone", type="filepath"),
]
outputs = [gr.components.Textbox()]

title = "Output the tags of a (music) audio clip"
description = "An example of using MERT-v0-public (95M) to conduct music tagging."
article = ""
audio_examples = []

# Load the vendored MERT model and its feature extractor from the local
# checkpoint directory.
logger.info("loading MERT-v0-public from ./MERT-v0-public")
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
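
# Alternative (per the MERT model card on the Hugging Face Hub): the same
# weights can be loaded straight from the Hub without vendoring the modeling
# code locally:
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)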

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only: disable dropout

def convert_audio(inputs, microphone):
    # Prefer the microphone recording when one was provided.
    if microphone is not None:
        inputs = microphone

    waveform, sample_rate = torchaudio.load(inputs)

    # Resample to the rate the feature extractor expects.
    resample_rate = processor.sampling_rate
    if resample_rate != sample_rate:
        logger.info(f"resampling from {sample_rate} Hz to {resample_rate} Hz")
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)

    # torchaudio returns (channels, time); mix multi-channel audio down to
    # mono instead of flattening the channels end to end.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)
    waveform = waveform.view(-1)

    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)

    # Stack the per-layer hidden states: [n_layers + 1, time steps, hidden size]
    # after squeezing the batch dimension (13 x T x 768 for this 95M checkpoint).
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()

    return f"{device}: {all_layer_hidden_states.shape}"
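
# A minimal sketch (not part of the original demo) of how the stacked hidden
# states could feed an actual tagging head: average over time, then over
# layers, then project the clip embedding to per-tag scores. The plain layer
# mean, the 768-dim hidden size, and NUM_TAGS are assumptions, and the linear
# head is untrained here; the MERT authors suggest a learned weighted average
# over layers for downstream tasks.
NUM_TAGS = 50  # hypothetical tag vocabulary size
tag_head = torch.nn.Linear(768, NUM_TAGS).to(device)

def predict_tags(all_layer_hidden_states: torch.Tensor) -> torch.Tensor:
    time_reduced = all_layer_hidden_states.mean(dim=-2)  # [layers, 768]
    clip_embedding = time_reduced.mean(dim=0)            # [768]
    return torch.sigmoid(tag_head(clip_embedding))       # per-tag scores in [0, 1]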

audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=outputs,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)

demo = gr.Blocks()
with demo:
    gr.TabbedInterface([audio_chunked], ["Audio File"])

demo.launch(show_api=False)