- .gitignore +2 -0
- app.py +56 -0
- gradio_helper.py +80 -0
- html_helper.py +100 -0
- model_helper.py +161 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+amt/
+examples/
app.py
ADDED
@@ -0,0 +1,56 @@
+import glob
+import gradio as gr
+
+from gradio_helper import *
+
+AUDIO_EXAMPLES = glob.glob('/content/examples/*.*', recursive=True)
+YOUTUBE_EXAMPLES = ["https://www.youtube.com/watch?v=vMboypSkj3c"]
+
+theme = 'gradio/dracula_revamped'  # 'Insuz/Mocha'  # gr.themes.Soft()
+with gr.Blocks(theme=theme) as demo:
+
+    with gr.Row():
+        with gr.Column(scale=10):
+            gr.Markdown(
+                """
+                # YourMT3+: Bridging the Gap in Multi-instrument Music Transcription with Advanced Model Architectures and Cross-dataset Stem Augmentation
+                """)
+
+    with gr.Group():
+        with gr.Tab("Upload audio"):
+            # Input
+            audio_input = gr.Audio(label="Record Audio", type="filepath",
+                                   show_share_button=True, show_download_button=True)
+            # Display examples
+            gr.Examples(examples=AUDIO_EXAMPLES, inputs=audio_input)
+            # Submit button
+            transcribe_audio_button = gr.Button("Transcribe", variant="primary")
+            # Transcribe
+            output_tab1 = gr.HTML()
+            # audio_output = gr.Text(label="Audio Info")
+            # transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)
+            transcribe_audio_button.click(process_audio, inputs=audio_input, outputs=output_tab1)
+
+        with gr.Tab("From YouTube"):
+            with gr.Row():
+                # Input URL
+                youtube_url = gr.Textbox(label="YouTube Link URL",
+                                         placeholder="https://youtu.be/...")
+                # Play youtube
+                youtube_player = gr.HTML(render=True)
+            with gr.Row():
+                # Play button
+                play_video_button = gr.Button("Play", variant="primary")
+                # Submit button
+                transcribe_video_button = gr.Button("Transcribe", variant="primary")
+            # Transcribe
+            output_tab2 = gr.HTML(render=True)
+            # video_output = gr.Text(label="Video Info")
+            transcribe_video_button.click(process_video, inputs=youtube_url, outputs=output_tab2)
+            # Play
+            play_video_button.click(play_video, inputs=youtube_url, outputs=youtube_player)
+
+            # Display examples
+            gr.Examples(examples=YOUTUBE_EXAMPLES, inputs=youtube_url)
+
+demo.launch(debug=True)
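Two small notes on app.py: the callbacks process_audio, process_video, and play_video come from the star import of gradio_helper.py (an explicit form is sketched below, purely for illustration), and recursive=True in glob.glob only has an effect when the pattern contains "**", so the example glob above is effectively non-recursive.

    # Equivalent explicit imports for the star import (illustrative, not part of the commit).
    from gradio_helper import process_audio, process_video, play_video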
gradio_helper.py
ADDED
@@ -0,0 +1,80 @@
+# @title GradIO helper
+import os
+import subprocess
+import glob
+from typing import Tuple, Dict, Literal
+from ctypes import ArgumentError
+# from google.colab import output
+
+from model_helper import *
+from html_helper import *
+
+from pytube import YouTube
+import gradio as gr
+import torchaudio
+
+def prepare_media(source_path_or_url: os.PathLike,
+                  source_type: Literal['audio_filepath', 'youtube_url'],
+                  delete_video: bool = True) -> Dict:
+    """prepare media from source path or youtube, and return audio info"""
+    # Get audio_file
+    if source_type == 'audio_filepath':
+        audio_file = source_path_or_url
+    elif source_type == 'youtube_url':
+        # Download from youtube
+        try:
+            # Try PyTube first
+            yt = YouTube(source_path_or_url)
+            audio_stream = min(yt.streams.filter(only_audio=True), key=lambda s: s.bitrate)
+            mp4_file = audio_stream.download(output_path='downloaded')  # ./downloaded
+            audio_file = mp4_file[:-3] + 'mp3'
+            subprocess.run(['ffmpeg', '-i', mp4_file, '-ac', '1', audio_file])
+            os.remove(mp4_file)
+        except Exception as e:
+            try:
+                # Try alternative
+                print(f"Failed with PyTube, error: {e}. Trying yt-dlp...")
+                audio_file = './downloaded/yt_audio'
+                subprocess.run(['yt-dlp', '-x', source_path_or_url, '-f', 'bestaudio',
+                                '-o', audio_file, '--audio-format', 'mp3', '--restrict-filenames',
+                                '--force-overwrites'])
+                audio_file += '.mp3'
+            except Exception as e:
+                print(f"Alternative downloader failed, error: {e}. Please try again later!")
+                return None
+    else:
+        raise ValueError(source_type)
+
+    # Create info
+    info = torchaudio.info(audio_file)
+    return {
+        "filepath": audio_file,
+        "track_name": os.path.basename(audio_file).split('.')[0],
+        "sample_rate": int(info.sample_rate),
+        "bits_per_sample": int(info.bits_per_sample),
+        "num_channels": int(info.num_channels),
+        "num_frames": int(info.num_frames),
+        "duration": int(info.num_frames / info.sample_rate),
+        "encoding": str.lower(info.encoding),
+    }
+
+def process_audio(audio_filepath):
+    if audio_filepath is None:
+        return None
+    audio_info = prepare_media(audio_filepath, source_type='audio_filepath')
+    midifile = transcribe(model, audio_info)
+    midifile = to_data_url(midifile)
+    return create_html_from_midi(midifile)  # html midiplayer
+
+def process_video(youtube_url):
+    if 'youtu' not in youtube_url:
+        return None
+    audio_info = prepare_media(youtube_url, source_type='youtube_url')
+    midifile = transcribe(model, audio_info)
+    midifile = to_data_url(midifile)
+    return create_html_from_midi(midifile)  # html midiplayer
+
+def play_video(youtube_url):
+    if 'youtu' not in youtube_url:
+        return None
+    return create_html_youtube_player(youtube_url)
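Note that process_audio and process_video call transcribe(model, ...) on a module-level `model` that is not defined in any of the files added by this commit; presumably it is loaded once when the app starts. A minimal sketch of that assumed glue (the checkpoint id is a placeholder, not a real experiment name):

    # Hypothetical one-time model load, assuming the glue lives in gradio_helper.py or app.py.
    from model_helper import load_model_checkpoint

    model = load_model_checkpoint(args=['my_checkpoint@last'])  # 'my_checkpoint@last' is a placeholder exp_id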
html_helper.py
ADDED
@@ -0,0 +1,100 @@
+# @title HTML helper
+import re
+import base64
+def to_data_url(midi_filename):
+    """ This is crucial for Colab/WandB support. Thanks to Scott Hawley!!
+    https://github.com/drscotthawley/midi-player/blob/main/midi_player/midi_player.py
+
+    """
+    with open(midi_filename, "rb") as f:
+        encoded_string = base64.b64encode(f.read())
+    return 'data:audio/midi;base64,' + encoded_string.decode('utf-8')
+
+
+def to_youtube_embed_url(video_url):
+    regex = r"(?:https:\/\/)?(?:www\.)?(?:youtube\.com|youtu\.be)\/(?:watch\?v=)?(.+)"
+    return re.sub(regex, r"https://www.youtube.com/embed/\1", video_url)
+
+
+def create_html_from_midi(midifile):
+    html_template = """
+    <!DOCTYPE html>
+    <html>
+    <head>
+      <title>Awesome MIDI Player</title>
+      <script src="https://cdn.jsdelivr.net/combine/npm/tone@14.7.58,npm/@magenta/music@1.23.1/es6/core.js,npm/focus-visible@5,npm/html-midi-player@1.5.0">
+      </script>
+      <style>
+        /* Background color for the section */
+        #proll {{background-color:transparent}}
+
+        /* Custom player style */
+        #proll midi-player {{
+          display: block;
+          width: inherit;
+          margin: 4px;
+          margin-bottom: 0;
+        }}
+
+        #proll midi-player::part(control-panel) {{
+          background: #D8DAE8;
+          border-radius: 8px 8px 0 0;
+          border: 1px solid #A0A0A0;
+        }}
+
+        /* Custom visualizer style */
+        #proll midi-visualizer .piano-roll-visualizer {{
+          background: #F7FAFA;
+          border-radius: 0 0 8px 8px;
+          border: 1px solid #A0A0A0;
+          margin: 4px;
+          margin-top: 2;
+          overflow: auto;
+        }}
+
+        #proll midi-visualizer svg rect.note {{
+          opacity: 0.6;
+          stroke-width: 2;
+        }}
+
+        #proll midi-visualizer svg rect.note[data-instrument="0"] {{
+          fill: #e22;
+          stroke: #055;
+        }}
+
+        #proll midi-visualizer svg rect.note[data-instrument="2"] {{
+          fill: #2ee;
+          stroke: #055;
+        }}
+
+        #proll midi-visualizer svg rect.note[data-is-drum="true"] {{
+          fill: #888;
+          stroke: #888;
+        }}
+
+        #proll midi-visualizer svg rect.note.active {{
+          opacity: 0.9;
+          stroke: #34384F;
+        }}
+      </style>
+    </head>
+    <body>
+      <div>
+        <a href="{midifile}" target="_blank">Download MIDI</a> <br>
+        <section id="proll">
+          <midi-player src="{midifile}" sound-font="https://storage.googleapis.com/magentadata/js/soundfonts/sgm_plus" visualizer="#proll midi-visualizer">
+          </midi-player>
+          <midi-visualizer src="{midifile}">
+          </midi-visualizer>
+        </section>
+      </div>
+    </body>
+    </html>
+    """.format(midifile=midifile)
+    html = f"""<iframe style="width: 100%; height: 400px; overflow:auto" srcdoc='{html_template}'></iframe>"""
+    return html
+
+def create_html_youtube_player(youtube_url):
+    youtube_url = to_youtube_embed_url(youtube_url)
+    html = f"""<iframe width="560" height="315" src='{youtube_url}' title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>"""
+    return html
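For reference, a minimal usage sketch of the two MIDI helpers above: encode a MIDI file as a data URL, then wrap it in the iframe player markup that Gradio renders through gr.HTML. The file path is a placeholder, not a file from this repo.

    # Illustrative only; 'output.mid' is a placeholder path.
    midi_data_url = to_data_url('output.mid')
    player_html = create_html_from_midi(midi_data_url)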
model_helper.py
ADDED
@@ -0,0 +1,161 @@
+# @title Model helper
+import os           # used by transcribe() below
+import torchaudio   # used by transcribe() below
+os.chdir('/content/amt/src')  # stands in for the original Colab magic `%cd /content/amt/src`
+from collections import Counter
+import argparse
+import torch
+import numpy as np
+
+from model.init_train import initialize_trainer, update_config
+from utils.task_manager import TaskManager
+from config.vocabulary import drum_vocab_presets
+from utils.utils import str2bool
+from utils.utils import Timer
+from utils.audio import slice_padded_array
+from utils.note2event import mix_notes
+from utils.event2note import merge_zipped_note_events_and_ties_to_notes
+from utils.utils import write_model_output_as_midi, write_err_cnt_as_json
+from model.ymt3 import YourMT3
+
+
+def load_model_checkpoint(args=None):
+    parser = argparse.ArgumentParser(description="YourMT3")
+    # General
+    parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.')
+    parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name')
+    parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.')
+    parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300}: 128 for MT3, 300 for PerceiverTF. If None, default value defined in config.py will be used.')
+    parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.')
+    # Model configurations
+    parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py')
+    parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.")
+    parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.")
+    parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None")
+    parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.")
+    parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.')
+    parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False')
+    parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False')
+    parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5 (default=False). True or False')
+    parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")')
+    parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.")
+    parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.")
+    parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.')
+    # Perceiver-TF configurations
+    parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.')
+    parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.')
+    parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
+    parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.')
+    parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.')
+    parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.')
+    parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.')
+    parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.')
+    # Decoder configurations
+    parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.')
+    parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.')
+    # Task and Evaluation configurations
+    parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.')
+    parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.')
+    parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.')
+    parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.')
+    parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).')
+    parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False')
+    parser.add_argument('-w', '--write-model-output', type=str2bool, default=True, help='write model test output to file (default=True). True or False')
+    # Trainer configurations
+    parser.add_argument('-pr', '--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}')
+    parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp')
+    parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)')
+    parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")')
+    parser.add_argument('-wb', '--wandb-mode', type=str, default="disabled", help='wandb mode for logging (default="disabled"). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.')
+    # Debug
+    parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False')
+    parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). semitone in int.')
+    args = parser.parse_args(args)
+    # yapf: enable
+    if torch.__version__ >= "1.13":
+        torch.set_float32_matmul_precision("high")
+    args.epochs = None
+
+    # Initialize and update config
+    _, _, dir_info, shared_cfg = initialize_trainer(args, stage='test')
+    shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test')
+
+    if args.eval_drum_vocab is not None:  # override eval_drum_vocab
+        eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab]
+
+    # Initialize task manager
+    tm = TaskManager(task_name=args.task,
+                     max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]),
+                     debug_mode=args.debug_mode)
+    print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}")
+
+    # Use GPU if available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Model
+    model = YourMT3(
+        audio_cfg=audio_cfg,
+        model_cfg=model_cfg,
+        shared_cfg=shared_cfg,
+        optimizer=None,
+        task_manager=tm,  # tokenizer is a member of task_manager
+        eval_subtask_key=args.eval_subtask_key,
+        write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
+    ).to(device)
+    checkpoint = torch.load(dir_info["last_ckpt_path"])
+    state_dict = checkpoint['state_dict']
+    new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
+    model.load_state_dict(new_state_dict, strict=False)
+    return model.eval()
+
+
+def transcribe(model, audio_info):
+    t = Timer()
+
+    # Converting Audio
+    t.start()
+    audio, sr = torchaudio.load(uri=audio_info['filepath'])
+    audio = torch.mean(audio, dim=0).unsqueeze(0)
+    audio = torchaudio.functional.resample(audio, sr, model.audio_cfg['sample_rate'])
+    audio_segments = slice_padded_array(audio, model.audio_cfg['input_frames'], model.audio_cfg['input_frames'])
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    audio_segments = torch.from_numpy(audio_segments.astype('float32')).to(device).unsqueeze(1)  # (n_seg, 1, seg_sz)
+    t.stop(); t.print_elapsed_time("converting audio")
+
+    # Inference
+    t.start()
+    pred_token_arr, _ = model.inference_file(bsz=8, audio_segments=audio_segments)
+    t.stop(); t.print_elapsed_time("model inference")
+
+    # Post-processing
+    t.start()
+    num_channels = model.task_manager.num_decoding_channels
+    n_items = audio_segments.shape[0]
+    start_secs_file = [model.audio_cfg['input_frames'] * i / model.audio_cfg['sample_rate'] for i in range(n_items)]
+    pred_notes_in_file = []
+    n_err_cnt = Counter()
+    for ch in range(num_channels):
+        pred_token_arr_ch = [arr[:, ch, :] for arr in pred_token_arr]  # (B, L)
+        zipped_note_events_and_tie, list_events, ne_err_cnt = model.task_manager.detokenize_list_batches(
+            pred_token_arr_ch, start_secs_file, return_events=True)
+        pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
+        pred_notes_in_file.append(pred_notes_ch)
+        n_err_cnt += n_err_cnt_ch
+    pred_notes = mix_notes(pred_notes_in_file)  # This is the mixed notes from all channels
+
+    # Write MIDI
+    write_model_output_as_midi(pred_notes, '/content/',
+                               audio_info['track_name'], model.midi_output_inverse_vocab)
+    t.stop(); t.print_elapsed_time("post processing")
+    midifile = os.path.join('/content/model_output/', audio_info['track_name'] + '.mid')
+    assert os.path.exists(midifile)
+    return midifile
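Taken together with gradio_helper.py, the intended flow appears to be: load the checkpoint once, build the audio_info dict with prepare_media, then call transcribe. A hedged end-to-end sketch, in which the checkpoint id and audio path are placeholders:

    # Illustrative only; prepare_media() builds the same audio_info dict that transcribe() expects.
    from gradio_helper import prepare_media

    model = load_model_checkpoint(args=['my_checkpoint@last'])     # placeholder exp_id
    audio_info = prepare_media('/content/examples/song.wav',       # placeholder audio file
                               source_type='audio_filepath')
    midi_path = transcribe(model, audio_info)                      # path to the .mid under /content/model_output/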