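# NOTE: the six lines below appear to be residue from a web file viewer
# (avatar caption, commit message and hash, viewer links, file size), not code.
# They are kept as comments so the module remains valid Python.
# asigalov61's picture
# Update app.py
# c16687b verified
# raw
# history blame
# 5.78 kB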
import os.path
import time as reqtime
import datetime
from pytz import timezone
import torch
import spaces
import gradio as gr
import random
from midi_to_colab_audio import midi_to_colab_audio
import TMIDIX
import matplotlib.pyplot as plt
from inference import PianoTranscription
from config import sample_rate
from utilities import load_audio
# =================================================================================================
@spaces.GPU
def TranscribePianoAudio(input_file):
    """Transcribe a solo piano audio file to MIDI and build the UI outputs.

    Steps, in order:
      1. Load the audio with ``load_audio`` (mono, at ``sample_rate``).
      2. Run ``PianoTranscription.transcribe`` with ``<input stem>.mid`` as the
         output MIDI path.
      3. Read that MIDI back with TMIDIX and derive the enhanced score notes.
      4. Render the MIDI to audio with ``midi_to_colab_audio`` and build a
         score plot with ``TMIDIX.plot_ms_SONG``.

    Reads the module-level names ``PDT`` (timezone used for log timestamps)
    and ``soundfont`` (soundfont path used for rendering); both are assigned
    in the ``__main__`` section of this file.

    Args:
        input_file: Path of the audio file to transcribe. It is passed to
            ``os.path.basename`` and ``load_audio``.

    Returns:
        A 5-tuple ``(title, summary, midi_path, audio, plot)`` where ``title``
        is the input file name without its extension, ``summary`` is the
        string form of the first three enhanced score notes, ``midi_path`` is
        the output MIDI file name, ``audio`` is ``(16000, rendered_audio)``
        and ``plot`` is the value returned by ``TMIDIX.plot_ms_SONG``.

    Raises:
        gr.Error: If ``input_file`` is empty or ``None``.
    """
    if not input_file:
        # Without this guard a click with no upload fails inside
        # os.path.basename with an opaque TypeError; show a readable message.
        raise gr.Error('Please upload a Solo Piano WAV or MP3 audio file first.')

    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = reqtime.time()
    print('=' * 70)

    fn = os.path.basename(input_file)
    # splitext drops only the final extension, so a name such as
    # "my.song.v2.mp3" keeps its full stem instead of collapsing to "my".
    fn1 = os.path.splitext(fn)[0]
    out_mid = fn1 + '.mid'

    print('-' * 70)
    print('Input file name:', fn)
    print('-' * 70)

    print('Loading audio...')
    # Load audio
    (audio, _) = load_audio(input_file, sr=sample_rate, mono=True)
    print('Done!')
    print('-' * 70)

    print('Loading transcriptor..')
    # Transcriptor
    transcriptor = PianoTranscription(device='cuda')  # 'cuda' | 'cpu'
    print('Done!')
    print('-' * 70)

    print('Transcribing...')
    # The return value is not needed here; the MIDI file at out_mid is read
    # back below.
    transcriptor.transcribe(audio, out_mid)
    print('Done!')
    print('-' * 70)

    #===============================================================================
    raw_score = TMIDIX.midi2single_track_ms_score(out_mid)

    #===============================================================================
    # Enhanced score notes
    print(raw_score)
    escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    #==================================================================
    print('=' * 70)
    print('Number of transcribed notes:', len(escore))
    print('Sample trascribed MIDI events', escore[:5])
    print('=' * 70)
    print('Done!')
    print('=' * 70)

    #===============================================================================
    print('Rendering results...')
    print('=' * 70)

    # One value feeds both the renderer and the returned audio tuple so the
    # two can never disagree.
    render_sample_rate = 16000

    rendered_audio = midi_to_colab_audio(out_mid,
                                         soundfont_path=soundfont,
                                         sample_rate=render_sample_rate,
                                         volume_scale=10,
                                         output_for_gradio=True
                                         )
    print('Done!')
    print('=' * 70)

    #========================================================
    output_midi_title = str(fn1)
    output_midi_summary = str(escore[:3])
    output_midi = str(out_mid)
    output_audio = (render_sample_rate, rendered_audio)
    output_plot = TMIDIX.plot_ms_SONG(escore, plot_title=output_midi_title, return_plt=True)

    print('Output MIDI file name:', output_midi)
    print('Output MIDI title:', output_midi_title)
    print('Output MIDI summary:', output_midi_summary)
    print('=' * 70)

    #========================================================
    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
# =================================================================================================
if __name__ == "__main__":
    # Script entry point: set the globals that TranscribePianoAudio reads,
    # build the Gradio interface and launch it.

    # Timezone used for every timestamp printed to the log.
    PDT = timezone('US/Pacific')

    rule = '=' * 70
    print(rule)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print(rule)

    # Soundfont path handed to midi_to_colab_audio when rendering the result.
    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    with gr.Blocks() as app:
        # ---- Page header and description ----
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>ByteDance Solo Piano Audio to MIDI Transcription</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Transcribe any Solo Piano WAV or MP3 audio to MIDI</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.ByteDance-Solo-Piano-Adio-to-MIDI-Transcription&style=flat)\n\n"
            "This is a ByteDance Solo Piano Audio to MIDI Transcription Model\n\n"
            "Check out [ByteDance Solo Piano Audio to MIDI Transcription](https://github.com/asigalov61/piano_transcription_inference) on GitHub!\n\n"
            "[Open In Colab]"
            "(https://colab.research.google.com/github/asigalov61/tegridy-tools/blob/main/tegridy-tools/notebooks/ByteDance_Piano_Transcription.ipynb)"
            " for faster execution and endless transcription"
        )

        # ---- Input section ----
        gr.Markdown("## Upload your Solo Piano WAV or MP3 audio or select a sample example audio file")
        input_audio = gr.File(
            label="Input Solo Piano WAV or MP3 Audio File",
            file_types=[".wav", ".mp3"],
        )
        run_btn = gr.Button("transcribe", variant="primary")

        # ---- Output section (creation order is kept from the original) ----
        gr.Markdown("## Generation results")
        output_midi_title = gr.Textbox(label="Output MIDI title")
        output_midi_summary = gr.Textbox(label="Output MIDI summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # Listed in the order of TranscribePianoAudio's return tuple.
        result_components = [
            output_midi_title,
            output_midi_summary,
            output_midi,
            output_audio,
            output_plot,
        ]

        run_event = run_btn.click(
            fn=TranscribePianoAudio,
            inputs=[input_audio],
            outputs=result_components,
        )

        gr.Examples(
            examples=[["cut_liszt.mp3"]],
            inputs=[input_audio],
            outputs=list(result_components),
            fn=TranscribePianoAudio,
            cache_examples=True,
        )

    app.queue().launch()