Spaces:
Running
Running
import os | |
import re | |
import torch | |
import shutil | |
import requests | |
import gradio as gr | |
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate | |
from modelscope import snapshot_download | |
from tempfile import NamedTemporaryFile | |
from pydub.utils import mediainfo | |
from urllib.parse import urlparse | |
from convert import * | |
# Scratch directory for the downloaded audio and the transcribed MIDI file;
# clean_cache() wipes and recreates it at the start of every request.
CACHE_DIR = "./__pycache__"

# snapshot_download() is called once at import time; its return value is used
# as the directory that contains the transcription checkpoint file.
_WEIGHTS_DIR = snapshot_download("monetjoe/CRNN_note_F1_0.9677_pedal_F1_0.9186")
WEIGHTS_PATH = _WEIGHTS_DIR + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
def clean_cache(cache_dir=CACHE_DIR):
    """Reset ``cache_dir`` to an empty directory.

    Any existing directory at that path is removed together with its
    contents, then the directory is created again.

    Args:
        cache_dir: Directory to wipe and recreate. Defaults to ``CACHE_DIR``.
    """
    stale_cache_present = os.path.exists(cache_dir)
    if stale_cache_present:
        shutil.rmtree(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)
def get_audio_file_type(file_path):
    """Derive a file extension for ``file_path`` from its probed media format.

    Args:
        file_path: Path of the media file to inspect.

    Returns:
        The extension including the leading dot (for example ``".mp3"``), or
        ``None`` when the format could not be determined.
    """
    try:
        # Probe the file for its media information.
        info = mediainfo(file_path)
        # Some containers are reported as a comma-separated list of aliases
        # (e.g. "mov,mp4,m4a,3gp,3g2,mj2"); only the first alias is usable as
        # a file extension, so drop the rest.
        primary_format = info["format_name"].split(",")[0].strip()
        return "." + primary_format

    except Exception as e:
        # Best effort: a probing failure or a missing "format_name" key is
        # reported on stdout and signalled to the caller with None.
        print(f"Error occurred: {e}")
        return None
def download_audio(url, save_path, timeout=30):
    """Download ``url`` and store it at ``save_path`` plus a detected extension.

    The body is streamed into a temporary file first, the media format is
    probed with ``get_audio_file_type``, and the file is then moved to its
    final name.

    Args:
        url: Direct link to the audio resource.
        save_path: Destination path without extension.
        timeout: Seconds to wait for the server (connect and each read)
            before giving up; passed straight to ``requests.get``.

    Returns:
        The final file path, or ``""`` when the server did not answer with
        HTTP 200.

    Raises:
        requests.RequestException: On connection errors or timeouts.
    """
    temp_file_path = None
    try:
        with NamedTemporaryFile(delete=False, suffix="_temp") as tmp_file:
            temp_file_path = tmp_file.name
            # Send the HTTP GET request; the context manager releases the
            # connection even when streaming is aborted half-way.
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Check whether the request succeeded.
                if response.status_code != 200:
                    print(f"Failed to download file: HTTP {response.status_code}")
                    return ""

                # Write the audio content to the temporary file.
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)

        # An undetectable format yields no extension rather than the literal
        # text "None" appended to the file name.
        ext = get_audio_file_type(temp_file_path) or ""
        full_path = f"{save_path}{ext}"
        # Rename the temporary file so that it carries the right extension.
        shutil.move(temp_file_path, full_path)
        return full_path

    finally:
        # The temp file is created with delete=False, so remove it ourselves
        # on every path where it was not moved to its final location.
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
def is_url(s: str):
    """Tell whether ``s`` parses as a URL with both a scheme and a host part.

    Args:
        s: Candidate string.

    Returns:
        ``True`` when ``urlparse`` finds a non-empty scheme (such as ``http``
        or ``https``) and a non-empty network location; ``False`` otherwise,
        including when the value cannot be parsed at all.
    """
    try:
        # Parse the string into its URL components.
        result = urlparse(s)
    except (ValueError, TypeError, AttributeError):
        # Malformed input (e.g. an unbalanced IPv6 bracket) or a non-string
        # argument is simply "not a URL". Only parsing errors are caught so
        # that KeyboardInterrupt / SystemExit are no longer swallowed.
        return False

    # Both the scheme and the network location must be present.
    return bool(result.scheme and result.netloc)
def audio2midi(audio_path: str):
    """Transcribe the audio file at ``audio_path`` into a MIDI file.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        A ``(midi_path, title)`` tuple: the path of the written MIDI file
        (``output.mid`` inside ``CACHE_DIR``) and a title derived from the
        audio file name.
    """
    # Load the audio as mono at the sample rate exported by the library.
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)

    # Build the transcriptor on the GPU when one is available, else on CPU.
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )

    # Transcribe and write out to the MIDI file.
    midi_path = f"{CACHE_DIR}/output.mid"
    transcriptor.transcribe(audio, midi_path)

    # The title is the dot-separated piece just before the extension. A file
    # name without any dot has no such piece, so use the whole name instead
    # of failing with IndexError.
    name_parts = os.path.basename(audio_path).split(".")
    title = name_parts[-2] if len(name_parts) > 1 else name_parts[0]
    return midi_path, title.capitalize()
def inference(audio_path: str):
    """Handle an uploaded audio file: transcribe it and build every score format.

    The cache directory is reset first, the audio is transcribed to MIDI, and
    the MIDI is passed through the converters in turn.

    Args:
        audio_path: Path of the uploaded audio file.

    Returns:
        A 6-tuple ``(midi, pdf, xml, mxl, abc, jpg)`` in the order of the
        output components of the "Upload mode" tab.
    """
    clean_cache()

    midi_file, score_title = audio2midi(audio_path)
    musicxml_file = midi2xml(midi_file, score_title)

    # All remaining formats are produced from the MusicXML result.
    abc_notation = xml2abc(musicxml_file)
    mxl_file = xml2mxl(musicxml_file)
    pdf_file, staff_image = xml2jpg(musicxml_file)

    return (
        midi_file,
        pdf_file,
        musicxml_file,
        mxl_file,
        abc_notation,
        staff_image,
    )
def get_first_integer(input_string):
    """Return the first run of digits in ``input_string`` as a normalised string.

    The digits are round-tripped through ``int`` so leading zeros are dropped
    (``"007"`` becomes ``"7"``).

    Args:
        input_string: Text to search.

    Returns:
        The first integer found, as a string, or ``""`` when the text
        contains no digit.
    """
    digits = re.search(r"\d+", input_string)
    if digits is None:
        return ""

    return str(int(digits.group()))
def infer(audio_url: str):
    """Handle a direct audio link: download it, transcribe it, build the scores.

    A link that contains ``163`` and ``?id=`` and does not already end in
    ``.mp3`` is rewritten to the ``music.163.com`` outer media URL for that
    song id before downloading.

    Args:
        audio_url: Direct link to an audio file, or a song page link as
            described above.

    Returns:
        A 7-tuple ``(audio, midi, pdf, xml, mxl, abc, jpg)`` in the order of
        the output components of the "Direct link mode" tab.

    Raises:
        gr.Error: When ``audio_url`` is not a URL or the download fails, so
            the UI shows a readable message instead of a transcription error
            on a file that does not exist.
    """
    clean_cache()

    if not is_url(audio_url):
        # The cache was just wiped, so there is nothing to fall back on.
        raise gr.Error("The input is not a valid URL (a scheme and a host are required)")

    # Only rewrite when a song id can really be extracted; a link that merely
    # contains "163" without "?id=" is downloaded as given rather than
    # failing with IndexError.
    if (
        "163" in audio_url
        and not audio_url.endswith(".mp3")
        and "?id=" in audio_url
    ):
        song_id = get_first_integer(audio_url.split("?id=")[1])
        if song_id:
            audio_url = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            print(audio_url)

    download_path = download_audio(audio_url, f"{CACHE_DIR}/output")
    if not download_path:
        # download_audio signals an HTTP failure with an empty path.
        raise gr.Error("Failed to download audio from the given URL")

    midi, title = audio2midi(download_path)
    xml = midi2xml(midi, title)
    abc = xml2abc(xml)
    mxl = xml2mxl(xml)
    pdf, jpg = xml2jpg(xml)
    return download_path, midi, pdf, xml, mxl, abc, jpg
# Two-tab web UI: one tab transcribes an uploaded file, the other a direct
# link. Components are created inside their tab, input first and outputs in
# display order, and are then handed to gr.Interface.
with gr.Blocks() as iface:
    with gr.Tab("Upload mode"):
        upload_input = gr.Audio(
            label="Please upload 100% of the audio before clicking submit",
            type="filepath",
        )
        # Order matches the tuple returned by inference().
        upload_outputs = [
            gr.components.File(label="Download MIDI"),
            gr.components.File(label="Download PDF score"),
            gr.components.File(label="Download MusicXML"),
            gr.components.File(label="Download MXL"),
            gr.Textbox(label="abc notation", show_copy_button=True),
            gr.Image(label="Staff", type="filepath"),
        ]
        gr.Interface(
            fn=inference,
            inputs=upload_input,
            outputs=upload_outputs,
            allow_flagging="never",
        )

    with gr.Tab("Direct link mode"):
        link_input = gr.Textbox(
            label="Input audio direct link URL (Netease Cloud Music can directly input non-VIP song page link automatically resolved)"
        )
        # Order matches the tuple returned by infer(): the downloaded audio
        # comes first, followed by the same six results as in upload mode.
        link_outputs = [
            gr.Audio(label="Download audio", type="filepath"),
            gr.components.File(label="Download MIDI"),
            gr.components.File(label="Download PDF score"),
            gr.components.File(label="Download MusicXML"),
            gr.components.File(label="Download MXL"),
            gr.Textbox(label="abc notation", show_copy_button=True),
            gr.Image(label="Staff", type="filepath"),
        ]
        gr.Interface(
            fn=infer,
            inputs=link_input,
            outputs=link_outputs,
            allow_flagging="never",
        )

# Start the web server for the assembled UI.
iface.launch()