YourMT3 / app.py
mimbres's picture
Update app.py
a96cbe2 verified
raw
history blame
10.2 kB
import spaces
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'amt/src')))
import subprocess
from typing import Tuple, Dict, Literal
from ctypes import ArgumentError
from html_helper import *
from model_helper import *
from pytube import YouTube
import torchaudio
import glob
import gradio as gr
# @title Load Checkpoint
model_name = 'YPTF.MoE+Multi (noPS)'  # @param ["YMT3+", "YPTF+Single (noPS)", "YPTF+Multi (PS)", "YPTF.MoE+Multi (noPS)", "YPTF.MoE+Multi (PS)"]
precision = '16'  # if torch.cuda.is_available() else '32'# @param ["32", "bf16-mixed", "16"]
project = '2024'

# Reusable CLI-flag fragments shared by several model configurations.
_PERCEIVER_SPEC = ['-ac', 'spec', '-hop', '300', '-atc', '1']
_MULTI_T5 = ['-tk', 'mc13_full_plus_256', '-dec', 'multi-t5', '-nl', '26']
_PERCEIVER_MOE = ['-enc', 'perceiver-tf', '-sqr', '1', '-ff', 'moe',
                  '-wf', '4', '-nmoe', '8', '-kmoe', '2', '-act', 'silu',
                  '-epe', 'rope', '-rp', '1']

# model name -> (checkpoint file, model-specific flags). Each entry is wrapped
# below as [checkpoint, '-p', project, *flags, '-pr', precision].
_MODEL_CONFIGS = {
    "YMT3+": (
        "notask_all_cross_v6_xk2_amp0811_gm_ext_plus_nops_b72@model.ckpt",
        [],
    ),
    "YPTF+Single (noPS)": (
        "ptf_all_cross_rebal5_mirst_xk2_edr005_attend_c_full_plus_b100@model.ckpt",
        ['-enc', 'perceiver-tf', *_PERCEIVER_SPEC],
    ),
    # NOTE(review): this points at the same checkpoint file as
    # "YPTF.MoE+Multi (PS)" (its name contains "moe_wf4_n8k2") but is built
    # without the MoE flags, so architecture and weights probably disagree.
    # Confirm the intended non-MoE checkpoint name.
    "YPTF+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        [*_MULTI_T5, '-enc', 'perceiver-tf', *_PERCEIVER_SPEC],
    ),
    "YPTF.MoE+Multi (noPS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt",
        [*_MULTI_T5, *_PERCEIVER_MOE, *_PERCEIVER_SPEC],
    ),
    "YPTF.MoE+Multi (PS)": (
        "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b80_ps2@model.ckpt",
        [*_MULTI_T5, *_PERCEIVER_MOE, *_PERCEIVER_SPEC],
    ),
}

if model_name not in _MODEL_CONFIGS:
    raise ValueError(model_name)
checkpoint, _model_flags = _MODEL_CONFIGS[model_name]
args = [checkpoint, '-p', project, *_model_flags, '-pr', precision]

# Weights are loaded on the CPU first, then moved to CUDA.
# NOTE(review): the move is unconditional; it assumes a GPU (or the
# `spaces`/ZeroGPU runtime imported at the top of the file) is available at
# import time - confirm before running elsewhere.
model = load_model_checkpoint(args=args, device="cpu")
model.to("cuda")
# @title Gradio helpers
def _download_with_pytube(url):
    """Download an audio-only stream with PyTube and transcode it to a mono mp3.

    Returns the mp3 path. Raises on any download or conversion failure, so the
    caller can fall back to another downloader.
    """
    yt = YouTube(url)
    # Same stream choice as before: the lowest-bitrate audio-only stream.
    audio_stream = min(yt.streams.filter(only_audio=True), key=lambda s: s.bitrate)
    downloaded = audio_stream.download(output_path='downloaded')  # ./downloaded
    # splitext copes with any container extension (e.g. ".webm"); chopping off
    # three characters only produced a valid name for ".mp4".
    audio_file = os.path.splitext(downloaded)[0] + '.mp3'
    try:
        # -y: overwrite an existing mp3 instead of stopping at ffmpeg's
        #     interactive overwrite prompt.
        # check=True: a failed conversion raises (-> yt-dlp fallback) instead
        #     of returning a path to a file that may not exist.
        subprocess.run(['ffmpeg', '-y', '-i', downloaded, '-ac', '1', audio_file],
                       check=True)
    finally:
        # Remove the intermediate download whether or not conversion worked.
        os.remove(downloaded)
    return audio_file


def _download_with_ytdlp(url):
    """Extract the best audio stream as mp3 with the yt-dlp CLI; return its path.

    Raises ``subprocess.CalledProcessError`` when yt-dlp exits non-zero, and
    ``OSError`` when the executable cannot be started.
    """
    audio_file = './downloaded/yt_audio'
    subprocess.run(['yt-dlp', '-x', url, '-f', 'bestaudio',
                    '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
                    '--force-overwrites'], check=True)
    # The mp3 post-processing step appends the extension to the output name.
    return audio_file + '.mp3'


def _download_youtube_audio(url):
    """Download the audio of a YouTube URL as mp3: PyTube first, then yt-dlp.

    Returns the mp3 path, or None if both downloaders fail.
    """
    try:
        return _download_with_pytube(url)
    except Exception as e:
        print(f"Failed with PyTube, error: {e}. Trying yt-dlp...")
    try:
        return _download_with_ytdlp(url)
    except Exception as e:
        print(f"Alternative downloader failed, error: {e}. Please try again later!")
        return None


def prepare_media(source_path_or_url: os.PathLike,
                  source_type: Literal['audio_filepath', 'youtube_url'],
                  delete_video: bool = True) -> Dict:
    """Prepare media from a local path or YouTube URL and return its audio info.

    Args:
        source_path_or_url: Path to an audio file, or a YouTube URL.
        source_type: ``'audio_filepath'`` uses the path as-is;
            ``'youtube_url'`` downloads the audio first (PyTube, falling back
            to yt-dlp) into ``./downloaded``.
        delete_video: Unused; kept for interface compatibility.

    Returns:
        A dict with ``filepath``, ``track_name`` (file name without its final
        extension), ``sample_rate``, ``bits_per_sample``, ``num_channels``,
        ``num_frames``, ``duration`` (whole seconds, truncated) and
        ``encoding`` (lower-case). Returns None if every YouTube downloader
        failed.

    Raises:
        ValueError: If ``source_type`` is not a supported value.
    """
    if source_type == 'audio_filepath':
        audio_file = source_path_or_url
    elif source_type == 'youtube_url':
        audio_file = _download_youtube_audio(source_path_or_url)
        if audio_file is None:
            return None
    else:
        raise ValueError(source_type)

    info = torchaudio.info(audio_file)
    return {
        "filepath": audio_file,
        # splitext keeps inner dots ("a.b.mp3" -> "a.b"); split('.')[0] gave "a".
        "track_name": os.path.splitext(os.path.basename(audio_file))[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample),
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": str.lower(info.encoding),
    }
@spaces.GPU
def process_audio(audio_filepath):
    """Transcribe a local audio file and render the result as an HTML MIDI player.

    Args:
        audio_filepath: File path delivered by the ``gr.Audio(type="filepath")``
            input, or None when nothing was provided.

    Returns:
        HTML for the MIDI player, or None if no file was given.
    """
    if audio_filepath is not None:
        media = prepare_media(audio_filepath, source_type='audio_filepath')
        midi_data_url = to_data_url(transcribe(model, media))
        return create_html_from_midi(midi_data_url)  # html midiplayer
    return None
@spaces.GPU
def process_video(youtube_url):
    """Download a YouTube video's audio, transcribe it, and return an HTML MIDI player.

    Args:
        youtube_url: Text from the YouTube URL textbox.

    Returns:
        HTML for the MIDI player, or None when the input is empty, does not
        look like a YouTube link, or the audio could not be downloaded.
    """
    # Check for an empty/None value first: the substring test alone raises
    # TypeError on None.
    if not youtube_url or 'youtu' not in youtube_url:
        return None
    audio_info = prepare_media(youtube_url, source_type='youtube_url')
    # prepare_media returns None when every downloader failed; transcribe()
    # presumably expects the info dict, so stop here instead of passing None on.
    if audio_info is None:
        return None
    midifile = transcribe(model, audio_info)
    midifile = to_data_url(midifile)
    return create_html_from_midi(midifile)  # html midiplayer
def play_video(youtube_url):
    """Return an embedded YouTube player for ``youtube_url``.

    Input that does not contain "youtu" (the substring shared by youtube.com
    and youtu.be links) yields None instead.
    """
    looks_like_youtube = 'youtu' in youtube_url
    return create_html_youtube_player(youtube_url) if looks_like_youtube else None
# Local audio examples for the "Upload audio" tab. glob returns files in
# arbitrary order, so sort for a stable example list. (recursive=True was a
# no-op here because the pattern has no "**".)
AUDIO_EXAMPLES = sorted(glob.glob('examples/*.*'))
# Example links for the "From YouTube" tab. Each entry must be exactly one URL.
YOUTUBE_EXAMPLES = [
    "https://www.youtube.com/watch?v=vMboypSkj3c",
    "https://youtu.be/5vJBhdjvVcE?si=s3NFG_SlVju0Iklg",
    "https://youtu.be/OXXRoa1U6xU?si=nhJ6lzGenCmk4P7R",
    "https://youtu.be/EOJ0wH6h3rE?si=a99k6BnSajvNmXcn",
    "https://youtu.be/7mjQooXt28o?si=qqmMxCxwqBlLPDI2",
    "https://youtu.be/bnS-HK_lTHA?si=PQLVAab3QHMbv0S3",
    "https://youtu.be/zJB0nnOc7bM?si=EA1DN8nHWJcpQWp_",
    "https://youtu.be/mIWYTg55h10?si=WkbtKfL6NlNquvT8",
]
# Start from the hosted Dracula theme, then shrink the text sizes and
# override the dark-mode colors (8-digit hex values carry an alpha channel).
theme = gr.Theme.from_hub("gradio/dracula_revamped")
_THEME_OVERRIDES = {
    'text_md': '9px',
    'text_lg': '11px',
    'body_background_fill_dark': '#060a1c',
    'border_color_primary_dark': '#45507328',
    'block_background_fill_dark': '#3845685c',
    'body_text_color_dark': 'white',
    'block_title_text_color_dark': 'black',
    'body_text_color_subdued_dark': '#e4e9e9',
}
for _attr_name, _attr_value in _THEME_OVERRIDES.items():
    setattr(theme, _attr_name, _attr_value)

# Four-color diagonal gradient behind the whole app; the "gradient" keyframes
# pan the 400%-sized background back and forth over 15 seconds.
css = """
.gradio-container {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
height: 100vh;
}
@keyframes gradient {
0% {background-position: 0% 50%;}
50% {background-position: 100% 50%;}
100% {background-position: 0% 50%;}
}
"""
with gr.Blocks(theme=theme, css=css) as demo:
    with gr.Row():
        with gr.Column(scale=10):
            # Header and model card.
            # NOTE(review): only `model_name` is interpolated. The backbone,
            # augmentation and precision bullets are fixed text matching the
            # "YPTF.MoE+Multi (noPS)" settings and need hand-editing if another
            # checkpoint is selected.
            gr.Markdown(
                f"""
## 🎶YourMT3+: Multi-instrument Music Transcription with Enhanced Transformer Architectures and Cross-dataset Stem Augmentation
### Model card:
- Model name: `{model_name}`
- Encoder backbone: Perceiver-TF + Mixture of Experts (2/8)
- Decoder backbone: Multi-channel T5-small
- Tokenizer: MT3 tokens with Singing extension
- Dataset: YourMT3 dataset
- Augmentation strategy: Intra-/Cross dataset stem augment, No Pitch-shifting
- FP Precision: BF16-mixed for training, FP16 for inference
#### Caution:
- For academic reproduction purposes, we strongly recommend using the [Colab Demo](https://colab.research.google.com/drive/1AgOVEBfZknDkjmSRA7leoa81a2vrnhBG?usp=sharing) with multiple checkpoints.
<div style="display: inline-block;">
<a href="https://arxiv.org/abs/2407.04822">
<img src="https://img.shields.io/badge/arXiv-B31B1B?logo=arxiv&logoColor=fff&style=plastic" alt="arXiv Badge"/>
</a>
</div>
<div style="display: inline-block;">
<a href="https://github.com/mimbres/YourMT3">
<img src="https://img.shields.io/badge/GitHub-181717?logo=github&logoColor=fff&style=plastic" alt="GitHub Badge"/>
</a>
</div>
<div style="display: inline-block;">
<a href="https://huggingface.co/spaces/mimbres/YourMT3">
<img src="https://img.shields.io/badge/Model%20on-🤗-1f425f.svg?style=plastic" alt="Hugging Face Badge"/>
</a>
</div>
""")
            with gr.Group():
                # Tab 1: transcribe an uploaded or recorded audio file.
                with gr.Tab("Upload audio"):
                    # Input
                    audio_input = gr.Audio(label="Record Audio", type="filepath",
                                           show_share_button=True, show_download_button=True)
                    # Display examples
                    gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)
                    # Submit button
                    transcribe_audio_button = gr.Button("Transcribe", variant="primary")
                    # Transcription result (HTML MIDI player)
                    output_tab1 = gr.HTML()
                    transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)

                # Tab 2: preview a YouTube link, then transcribe its audio.
                with gr.Tab("From YouTube"):
                    with gr.Column(scale=4):
                        # Input URL
                        youtube_url = gr.Textbox(label="YouTube Link URL",
                                                 placeholder="https://youtu.be/...")
                        # Display examples
                        gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)
                        # Play button
                        play_video_button = gr.Button("Get Audio from YouTube", variant="primary")
                        # Embedded YouTube player
                        youtube_player = gr.HTML(render=True)
                    with gr.Column(scale=4):
                        # Submit button
                        transcribe_video_button = gr.Button("Transcribe", variant="primary")
                    with gr.Column(scale=1):
                        # Transcription result (HTML MIDI player)
                        output_tab2 = gr.HTML(render=True)
                    # Transcribe
                    transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)
                    # Play
                    play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)

demo.launch(debug=True)