TobDeBer's picture
description update
06fa0ef
import gradio as gr
#
from transformers import Wav2Vec2FeatureExtractor
from transformers import AutoModel
import torch
from torch import nn
import torchaudio
import torchaudio.transforms as T
import logging
import json
import os
import re
import pandas as pd
import importlib
modeling_MERT = importlib.import_module("MERT-v1-95M.modeling_MERT")
from Prediction_Head.MTGGenre_head import MLPProberBase
# input cr: https://huggingface.co/spaces/thealphhamerc/audio-to-text/blob/main/app.py
logger = logging.getLogger("MERT-v1-95M-app")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch.setFormatter(formatter)
logger.addHandler(ch)
inputs = [
gr.components.Audio(type="filepath", label="Add music audio file"),
gr.inputs.Audio(source="microphone", type="filepath"),
]
live_inputs = [
gr.Audio(source="microphone",streaming=True, type="filepath"),
]
title = "One Model for All Music Understanding Tasks"
description = "An example of using the [MERT-v1-95M](https://huggingface.co/m-a-p/MERT-v1-95M) model as backbone to conduct multiple music understanding tasks with the universal representation. \n Due the hardware limitation of the machine hosting this demo (2 CPU and 16GB RAM) only the first 4 seconds of audio are used!"
# article = "The tasks include EMO, GS, MTGInstrument, MTGGenre, MTGTop50, MTGMood, NSynthI, NSynthP, VocalSetS, VocalSetT. \n\n More models can be referred at the [map organization page](https://huggingface.co/m-a-p)."
with open('./README.md', 'r') as f:
# skip the header
header_count = 0
for line in f:
if '---' in line:
header_count += 1
if header_count >= 2:
break
# read the rest conent
article = f.read()
audio_examples = [
# ["input/example-1.wav"],
# ["input/example-2.wav"],
]
df_init = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
transcription_df = gr.DataFrame(value=df_init, label="Output Dataframe", row_count=(
0, "dynamic"), max_rows=30, wrap=True, overflow_row_behaviour='paginate')
# outputs = [gr.components.Textbox()]
outputs = transcription_df
df_init_live = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
transcription_df_live = gr.DataFrame(value=df_init_live, label="Output Dataframe", row_count=(
0, "dynamic"), max_rows=30, wrap=True, overflow_row_behaviour='paginate')
outputs_live = transcription_df_live
# Load the model and the corresponding preprocessor config
# model = AutoModel.from_pretrained("m-a-p/MERT-v0-public", trust_remote_code=True)
# processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v0-public",trust_remote_code=True)
model = modeling_MERT.MERTModel.from_pretrained("./MERT-v1-95M")
processor = Wav2Vec2FeatureExtractor.from_pretrained("./MERT-v1-95M")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MERT_BEST_LAYER_IDX = {
'EMO': 5,
'GS': 8,
'GTZAN': 7,
'MTGGenre': 7,
'MTGInstrument': 'all',
'MTGMood': 6,
'MTGTop50': 6,
'MTT': 'all',
'NSynthI': 6,
'NSynthP': 1,
'VocalSetS': 2,
'VocalSetT': 9,
}
MERT_BEST_LAYER_IDX = {
'EMO': 5,
'GS': 8,
'GTZAN': 7,
'MTGGenre': 7,
'MTGInstrument': 'all',
'MTGMood': 6,
'MTGTop50': 6,
'MTT': 'all',
'NSynthI': 6,
'NSynthP': 1,
'VocalSetS': 2,
'VocalSetT': 9,
}
CLASSIFIERS = {
}
ID2CLASS = {
}
TASKS = ['GS', 'MTGInstrument', 'MTGGenre', 'MTGTop50', 'MTGMood', 'NSynthI', 'NSynthP', 'VocalSetS', 'VocalSetT','EMO',]
Regression_TASKS = ['EMO']
head_dir = './Prediction_Head/best-layer-MERT-v1-95M'
for task in TASKS:
print('loading', task)
with open(os.path.join(head_dir,f'{task}.id2class.json'), 'r') as f:
ID2CLASS[task]=json.load(f)
num_class = len(ID2CLASS[task].keys())
CLASSIFIERS[task] = MLPProberBase(d=768, layer=MERT_BEST_LAYER_IDX[task], num_outputs=num_class)
CLASSIFIERS[task].load_state_dict(torch.load(f'{head_dir}/{task}.ckpt')['state_dict'])
CLASSIFIERS[task].to(device)
model.to(device)
def model_infernce(inputs):
waveform, sample_rate = torchaudio.load(inputs)
resample_rate = processor.sampling_rate
# make sure the sample_rate aligned
if resample_rate != sample_rate:
# print(f'setting rate from {sample_rate} to {resample_rate}')
resampler = T.Resample(sample_rate, resample_rate)
waveform = resampler(waveform)
#waveform = waveform.view(-1,) # make it (n_sample, )
waveform = waveform[0][0:4*resample_rate] # cut to 4s samples
model_inputs = processor(waveform, sampling_rate=resample_rate, return_tensors="pt")
model_inputs.to(device)
with torch.no_grad():
model_outputs = model(**model_inputs, output_hidden_states=True)
# take a look at the output shape, there are 13 layers of representation
# each layer performs differently in different downstream tasks, you should choose empirically
all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:,:,:].unsqueeze(0)
print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]
all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)
task_output_texts = ""
df = pd.DataFrame(columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
df_objects = []
for task in TASKS:
num_class = len(ID2CLASS[task].keys())
if MERT_BEST_LAYER_IDX[task] == 'all':
logits = CLASSIFIERS[task](all_layer_hidden_states) # [1, 87]
else:
logits = CLASSIFIERS[task](all_layer_hidden_states[:, MERT_BEST_LAYER_IDX[task]])
# print(f'task {task} logits:', logits.shape, 'num class:', num_class)
sorted_idx = torch.argsort(logits, dim = -1, descending=True)[0] # batch =1
sorted_prob,_ = torch.sort(nn.functional.softmax(logits[0], dim=-1), dim=-1, descending=True)
# print(sorted_prob)
# print(sorted_prob.shape)
top_n_show = 5 if num_class >= 5 else num_class
# task_output_texts = task_output_texts + f"TASK {task} output:\n" + "\n".join([str(ID2CLASS[task][str(sorted_idx[idx].item())])+f', probability: {sorted_prob[idx].item():.2%}' for idx in range(top_n_show)]) + '\n'
# task_output_texts = task_output_texts + '----------------------\n'
row_elements = [task]
for idx in range(top_n_show):
print(ID2CLASS[task])
# print('id', str(sorted_idx[idx].item()))
output_class_name = str(ID2CLASS[task][str(sorted_idx[idx].item())])
output_class_name = re.sub(r'^\w+---', '', output_class_name)
output_class_name = re.sub(r'^\w+\/\w+---', '', output_class_name)
# print('output name', output_class_name)
output_prob = f' {sorted_prob[idx].item():.2%}'
row_elements.append(output_class_name+output_prob)
# fill empty elment
for _ in range(5+1 - len(row_elements)):
row_elements.append(' ')
df_objects.append(row_elements)
df = pd.DataFrame(df_objects, columns=['Task', 'Top 1', 'Top 2', 'Top 3', 'Top 4', 'Top 5'])
return df
def convert_audio(inputs, microphone):
if (microphone is not None):
inputs = microphone
df = model_infernce(inputs)
return df
def live_convert_audio(microphone):
if (microphone is not None):
inputs = microphone
df = model_infernce(inputs)
return df
audio_chunked = gr.Interface(
fn=convert_audio,
inputs=inputs,
outputs=outputs,
allow_flagging="never",
title=title,
description=description,
article=article,
examples=audio_examples,
)
live_audio_chunked = gr.Interface(
fn=live_convert_audio,
inputs=live_inputs,
outputs=outputs_live,
allow_flagging="never",
title=title,
description=description,
article=article,
# examples=audio_examples,
live=True,
)
demo = gr.Blocks()
with demo:
gr.TabbedInterface(
[
audio_chunked,
live_audio_chunked,
],
[
"Audio File or Recording",
"Live Streaming Music"
]
)
# demo.queue(concurrency_count=1, max_size=5)
demo.launch(show_api=False)