"""Gradio app: transcribe piano audio to MIDI and render it as sheet music.

Two tabs are offered: upload an audio file, or paste a direct audio URL
(NetEase Cloud Music song-page links are rewritten to their MP3 media URL).
The MIDI transcription is then converted to MusicXML, MXL, ABC, PDF and an
image of the score.
"""

import functools
import os
import re
import shutil
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse

import gradio as gr
import requests
import torch
from modelscope import snapshot_download
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from pydub.utils import mediainfo

# Expected to provide midi2xml, xml2abc, xml2mxl and xml2jpg (used below).
from convert import *

# Scratch directory for per-request artifacts; wiped by clean_cache().
CACHE_DIR = "./__pycache__"
# Model checkpoint, fetched from ModelScope at import time.
WEIGHTS_PATH = os.path.join(
    snapshot_download("monetjoe/CRNN_note_F1_0.9677_pedal_F1_0.9186"),
    "CRNN_note_F1=0.9677_pedal_F1=0.9186.pth",
)
# (connect, read) timeouts in seconds for audio downloads, so a stalled
# server cannot hang a request forever.
DOWNLOAD_TIMEOUT = (10, 60)
# Numeric song id in an "?id=123" / "&id=123" query parameter.
_NETEASE_ID_PATTERN = re.compile(r"[?&]id=(\d+)")


def clean_cache(cache_dir=CACHE_DIR):
    """Remove *cache_dir* if it exists, then recreate it empty."""
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)


def get_audio_file_type(file_path):
    """Return the container format of *file_path* as an extension such as ".mp3".

    The format is probed with pydub's ``mediainfo`` (ffprobe). ffprobe may
    report several comma-separated aliases (e.g. "mov,mp4,m4a,3gp,3g2,mj2");
    only the first is used. Returns None if probing fails.
    """
    try:
        format_name = mediainfo(file_path)["format_name"]
    except Exception as e:  # best effort: ffprobe missing, unreadable file, ...
        print(f"Error occurred: {e}")
        return None
    first = str(format_name).split(",")[0].strip()
    return f".{first}" if first else None


def download_audio(url, save_path, timeout=DOWNLOAD_TIMEOUT):
    """Download *url* to ``save_path + <detected extension>``.

    Args:
        url: HTTP(S) URL of the audio file.
        save_path: Destination path without extension; its directory must exist.
        timeout: ``requests`` timeout (seconds, or a (connect, read) tuple).

    Returns:
        The final file path, or "" if the download failed. If the format
        cannot be detected, the file is saved at *save_path* with no extension.
    """
    # Create the temp file next to the destination so the final move is a
    # same-filesystem rename rather than a copy.
    save_dir = os.path.dirname(save_path) or "."
    temp_file_path = None
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download file: HTTP {response.status_code}")
                return ""
            with NamedTemporaryFile(
                delete=False, suffix="_temp", dir=save_dir
            ) as tmp_file:
                temp_file_path = tmp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)
    except requests.RequestException as e:
        print(f"Failed to download file: {e}")
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        return ""

    # The temp file is closed (hence flushed) here, so ffprobe sees the
    # complete download and the move cannot copy a partially buffered file.
    ext = get_audio_file_type(temp_file_path) or ""
    full_path = f"{save_path}{ext}"
    # Rename so the file carries the correct extension.
    shutil.move(temp_file_path, full_path)
    return full_path


def is_url(s: str) -> bool:
    """Return True if *s* has both a scheme (e.g. https) and a network location."""
    try:
        result = urlparse(s)
    except (TypeError, ValueError, AttributeError):
        # Malformed input (e.g. an invalid IPv6 host) is simply "not a URL".
        return False
    return bool(result.scheme and result.netloc)


def get_first_integer(input_string):
    """Return the first run of digits in *input_string* (leading zeros stripped), or ""."""
    match = re.search(r"\d+", input_string)
    if match:
        return str(int(match.group()))
    return ""


def _resolve_netease_url(url):
    """Rewrite a NetEase Cloud Music (163.com) song link to its direct MP3 URL.

    URLs on other hosts, and URLs that already end in ".mp3", are returned
    unchanged. Raises gr.Error if a 163.com link carries no numeric song id.
    """
    host = (urlparse(url).hostname or "").lower()
    is_netease = host == "163.com" or host.endswith(".163.com")
    if not is_netease or url.endswith(".mp3"):
        return url
    match = _NETEASE_ID_PATTERN.search(url)
    if match is None:
        raise gr.Error("无法从网易云音乐链接中解析歌曲 ID")
    song_id = str(int(match.group(1)))
    resolved = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
    print(resolved)
    return resolved


@functools.lru_cache(maxsize=1)
def _get_transcriptor():
    """Build the PianoTranscription model once and reuse it for every request."""
    return PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )


def audio2midi(audio_path: str):
    """Transcribe the audio at *audio_path* into ``CACHE_DIR/output.mid``.

    Returns:
        (midi_path, title): title is the file name without its extension,
        passed through ``str.capitalize``.
    """
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)
    midi_path = f"{CACHE_DIR}/output.mid"
    _get_transcriptor().transcribe(audio, midi_path)
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, stem.capitalize()


def _transcribe_and_convert(audio_path):
    """Transcribe *audio_path*; return (midi, pdf, xml, mxl, abc, jpg)."""
    midi, title = audio2midi(audio_path)
    xml = midi2xml(midi, title)
    abc = xml2abc(xml)
    mxl = xml2mxl(xml)
    pdf, jpg = xml2jpg(xml)
    return midi, pdf, xml, mxl, abc, jpg


def inference(audio_path: str):
    """Upload-tab handler: transcribe an uploaded audio file and export all formats."""
    if not audio_path:
        raise gr.Error("请先上传音频文件")
    clean_cache()
    return _transcribe_and_convert(audio_path)


def infer(audio_url: str):
    """URL-tab handler: download audio from *audio_url*, then transcribe and export it.

    Returns the downloaded audio path followed by the outputs of the upload tab.
    """
    clean_cache()
    audio_url = (audio_url or "").strip()
    if not is_url(audio_url):
        raise gr.Error("请输入有效的音频直链 URL")
    audio_url = _resolve_netease_url(audio_url)
    download_path = download_audio(audio_url, f"{CACHE_DIR}/output")
    if not download_path:
        raise gr.Error("音频下载失败,请检查链接是否有效")
    return (download_path, *_transcribe_and_convert(download_path))


with gr.Blocks() as iface:
    with gr.Tab("上传模式"):
        gr.Interface(
            fn=inference,
            inputs=gr.Audio(label="请上传音频100%后再点提交", type="filepath"),
            outputs=[
                gr.components.File(label="下载 MIDI"),
                gr.components.File(label="下载 PDF 乐谱"),
                gr.components.File(label="下载 MusicXML"),
                gr.components.File(label="下载 MXL"),
                gr.Textbox(label="abc 乐谱", show_copy_button=True),
                gr.Image(label="五线谱", type="filepath"),
            ],
            allow_flagging="never",
        )

    with gr.Tab("直链模式"):
        gr.Interface(
            fn=infer,
            inputs=gr.Textbox(
                label="输入音频直链URL(网易云音乐可直接输入非VIP歌曲页面链接自动解析)"
            ),
            outputs=[
                gr.Audio(label="下载音频", type="filepath"),
                gr.components.File(label="下载 MIDI"),
                gr.components.File(label="下载 PDF 乐谱"),
                gr.components.File(label="下载 MusicXML"),
                gr.components.File(label="下载 MXL"),
                gr.Textbox(label="abc 乐谱", show_copy_button=True),
                gr.Image(label="五线谱", type="filepath"),
            ],
            allow_flagging="never",
        )

iface.launch()