import gradio as gr
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
import torchaudio
import torchaudio.transforms as T
import logging
import importlib

modeling_MERT = importlib.import_module("MERT-v0-public.modeling_MERT")
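# The line above loads the MERT model class from a local clone of
# m-a-p/MERT-v0-public (expected at ./MERT-v0-public), so the Space can run in
# offline mode without trust_remote_code; the same directory is reused as the
# checkpoint path below.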
# Input components credit: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
logger = logging.getLogger("mert-music-tagging-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
inputs = [
    gr.components.Audio(type="filepath", label="Add music audio file"),
    gr.components.Audio(source="microphone", optional=True, type="filepath"),
]
outputs = [gr.components.Textbox()]
title = "Output the tags of a (music) audio"
description = "An example of using MERT-95M-public to conduct music tagging."
article = ""
audio_examples = [
    # ["input/example-1.wav"],
    # ["input/example-2.wav"],
]
# Load the model and the corresponding preprocessor config
# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v0-public")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v0-public")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
def convert_audio(inputs, microphone):
    # Prefer the microphone recording when one was provided
    if microphone is not None:
        inputs = microphone
    waveform, sample_rate = torchaudio.load(inputs)
    resample_rate = processor.sampling_rate
    # Make sure the sample rate matches what the feature extractor expects
    if resample_rate != sample_rate:
        logger.info(f'Resampling from {sample_rate} Hz to {resample_rate} Hz')
        resampler = T.Resample(sample_rate, resample_rate)
        waveform = resampler(waveform)
    waveform = waveform.mean(dim=0)  # downmix to mono, shape (n_samples,)
    model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs, output_hidden_states=True)
    # The output contains 13 layers of representation; each layer performs
    # differently on different downstream tasks, so choose one empirically.
    all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()
    # all_layer_hidden_states.shape -> [13 layers, time steps, 768 feature_dim]
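    # A common next step (sketched in the MERT model card example; not used by this
    # demo) is to average over the time axis to obtain fixed-size, per-layer clip
    # embeddings that a downstream tagging head could consume, e.g.:
    #   time_reduced = all_layer_hidden_states.mean(dim=-2)  # -> [13, 768]
    # Here we simply report the tensor shape below.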
    return f"{device} : {all_layer_hidden_states.shape}"
# iface = gr.Interface(fn=convert_audio, inputs="audio", outputs="text")
# iface.launch()
audio_chunked = gr.Interface(
    fn=convert_audio,
    inputs=inputs,
    outputs=outputs,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=audio_examples,
)
demo = gr.Blocks()
with demo:
    gr.TabbedInterface([audio_chunked], ["Audio File"])

# demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)
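# To try this locally, run the script (assuming it is saved as app.py: `python app.py`)
# with a clone of m-a-p/MERT-v0-public placed at ./MERT-v0-public; Gradio prints a
# local URL to open in the browser.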