juancopi81's picture
Change to gradio blocks, add radio for selecting the model
ed28ae4
raw
history blame
No virus
2.34 kB
import os
os.system("python3 -m pip install -e .")
import gradio as gr
import note_seq
from inferencemodel import InferenceModel
from utils import upload_audio
SAMPLE_RATE = 16000
# SoundFont file path; presumably meant for rendering MIDI to audio
# (not referenced in the code visible here).
SF2_PATH = "SGM-v2.01-Sal-Guit-Bass-V1.3.sf2"

# Name of the model loaded when the app starts.
current_model = "mt3"
# Load that model once at import time; change_model() rebinds this global
# when the user picks a different model in the UI.
inference_model = InferenceModel(
    f"/home/user/app/checkpoints/{current_model}/", current_model
)
def change_model(model):
    """Swap the global inference model for the one named by *model*.

    Wired to the model radio button's ``change`` event.

    Args:
        model: Model type selected in the UI (``"mt3"`` or ``"ismir2021"``).
            A falsy value (radio cleared) leaves the current model in place.
    """
    global current_model, inference_model
    # Nothing to do when no model is selected or it is already loaded.
    if not model or model == current_model:
        return
    # Load from the selected model's own checkpoint directory instead of
    # always using .../mt3/. NOTE(review): assumes each model's checkpoint
    # lives at /home/user/app/checkpoints/<model>/ -- confirm the ismir2021
    # checkpoint is downloaded there.
    inference_model = InferenceModel(f"/home/user/app/checkpoints/{model}/", model)
    current_model = model
def inference(audio):
    """Transcribe an audio file to MIDI and return the MIDI file's path.

    Args:
        audio: Path to the input audio file.

    Returns:
        Path of the written ``transcribed.mid`` file.
    """
    import tempfile

    with open(audio, "rb") as fd:
        contents = fd.read()
    samples = upload_audio(contents, sample_rate=SAMPLE_RATE)
    est_ns = inference_model(samples)
    # Use a fresh directory per request so concurrent requests don't
    # overwrite each other's output; the file name stays "transcribed.mid".
    out_path = os.path.join(tempfile.mkdtemp(), "transcribed.mid")
    note_seq.sequence_proto_to_midi_file(est_ns, out_path)
    return out_path
# Text shown in the UI header and footer.
title = "Transcribe music from YouTube videos using Transformers."
description = """
Gradio demo for Music Transcription with Transformers. Read more in the links below.
"""
# HTML footer linking to the MT3 paper and its GitHub repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.03017' target='_blank'>"
    "MT3: Multi-Task Multitrack Music Transcription</a> | "
    "<a href='https://github.com/magenta/mt3' target='_blank'>Github Repo</a>"
    "</p>"
)
# Build the Gradio Blocks UI and use it as a context for its components.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        "<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>"
    )
    gr.Markdown(description)
    with gr.Box():
        with gr.Row():
            gr.Markdown("<h2>Select your model</h2>")
        # Text kept at column 0 so Markdown does not render it as a code block.
        gr.Markdown("""
The ismir2021 model transcribes piano only, with note velocities.
The mt3 model transcribes multiple simultaneous instruments, but without velocities.
""")
        # Default to "mt3" to match the model loaded at startup.
        model = gr.Radio(
            ["mt3", "ismir2021"],
            value="mt3",
            label="Which model do you want to use?",
        )
        model.change(fn=change_model, inputs=model, outputs=[])
    # NOTE(review): `inference` is not yet connected to any input/output
    # component in this UI, so no transcription can be triggered from here.
    # Footer with the links the description refers to.
    gr.HTML(article)

demo.launch()
""" gr.Interface(
inference,
gr.inputs.Audio(type="filepath", label="Input"),
[gr.outputs.File(label="Output")],
title=title,
description=description,
article=article,
examples=examples,
).launch().queue() """