File size: 5,169 Bytes
0487c7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a531320
0487c7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f401070
0487c7d
 
 
 
 
 
 
 
a531320
 
 
 
0487c7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import re
import torch
import shutil
import requests
import gradio as gr
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from modelscope import snapshot_download
from tempfile import NamedTemporaryFile
from pydub.utils import mediainfo
from urllib.parse import urlparse
from convert import *

# Working directory for per-request files: the downloaded audio and the
# transcribed MIDI are written here, and clean_cache() deletes and recreates
# it at the start of every request.
CACHE_DIR = "./__pycache__"
# Path of the checkpoint file handed to PianoTranscription in audio2midi().
# NOTE(review): snapshot_download presumably fetches the model repository and
# returns its local directory — confirm against the modelscope docs. The call
# is evaluated when this module is imported.
WEIGHTS_PATH = (
    snapshot_download("monetjoe/CRNN_note_F1_0.9677_pedal_F1_0.9186")
    + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
)


def clean_cache(cache_dir=CACHE_DIR):
    """Reset ``cache_dir`` to an empty directory.

    Anything already present at that path is removed recursively, after
    which the directory is created again.

    Args:
        cache_dir: Directory to wipe and recreate. Defaults to ``CACHE_DIR``.
    """
    stale_dir_present = os.path.exists(cache_dir)
    if stale_dir_present:
        shutil.rmtree(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)


def get_audio_file_type(file_path):
    """Return a dot-prefixed format name for the media file at ``file_path``.

    The value is built from the ``format_name`` entry of the media info
    reported for the file.

    Args:
        file_path: Path of the media file to probe.

    Returns:
        ``"." + format_name`` on success, or ``None`` if probing the file or
        reading the format name raised any exception (the error is printed).
    """
    try:
        # Probe the file and build the suffix from its reported format name.
        suffix = "." + mediainfo(file_path)["format_name"]
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
    else:
        return suffix


def download_audio(url, save_path, timeout=30):
    """Download ``url`` and store it at ``save_path`` plus a detected extension.

    The body is streamed into a temporary file first, the media type of that
    file is probed with ``get_audio_file_type``, and the file is then moved to
    ``save_path + extension``.

    Args:
        url: Direct link to the audio resource.
        save_path: Destination path *without* extension.
        timeout: Seconds passed to ``requests.get`` as its ``timeout`` so a
            stalled server cannot block the request forever.

    Returns:
        The final path of the downloaded file, or ``""`` when the server
        answered with a status other than 200.

    Raises:
        requests.RequestException: Propagated from ``requests`` on network
            errors or timeouts.
    """
    temp_file_path = None
    moved = False
    try:
        with NamedTemporaryFile(delete=False, suffix="_temp") as tmp_file:
            temp_file_path = tmp_file.name

            # Send the HTTP GET request; the context manager releases the
            # connection even when we bail out early.
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Only a plain 200 counts as a successful download.
                if response.status_code != 200:
                    print(f"Failed to download file: HTTP {response.status_code}")
                    return ""

                # Stream the audio content into the temporary file.
                for chunk in response.iter_content(chunk_size=8192):
                    tmp_file.write(chunk)

        ext = get_audio_file_type(temp_file_path)
        if ext is None:
            # Probing failed: fall back to the suffix in the URL path (possibly
            # empty) instead of appending the literal text "None".
            ext = os.path.splitext(urlparse(url).path)[1]

        full_path = f"{save_path}{ext}"

        # Move the temporary file to its final name with the proper extension.
        shutil.move(temp_file_path, full_path)
        moved = True
        return full_path
    finally:
        # delete=False means nobody else cleans up: drop the temporary file on
        # every path that did not hand it over to ``full_path``.
        if not moved and temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)


def is_url(s: str):
    """Tell whether ``s`` looks like an absolute URL.

    Args:
        s: Candidate string.

    Returns:
        ``True`` when ``s`` parses with both a scheme (e.g. ``http``) and a
        network location (host); ``False`` otherwise, including when the
        string cannot be parsed at all.
    """
    try:
        result = urlparse(s)
    except (ValueError, AttributeError, TypeError):
        # Malformed input (e.g. a broken IPv6 literal) or a non-string value.
        # Only parsing errors are caught so that KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        return False

    # Both the scheme and the host must be present.
    return all([result.scheme, result.netloc])


def audio2midi(audio_path: str):
    """Transcribe the audio file at ``audio_path`` to a MIDI file.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        A ``(midi_path, title)`` tuple: ``midi_path`` is the MIDI file written
        inside ``CACHE_DIR`` and ``title`` is the capitalized file name of
        ``audio_path`` without its extension.
    """
    # Load the audio as mono at the sample rate exported by the
    # transcription package.
    audio, _ = load_audio(audio_path, sr=sample_rate, mono=True)

    # Build the transcriptor on the GPU when one is available, else on the CPU.
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )

    # Transcribe and write the result out as a MIDI file.
    midi_path = f"{CACHE_DIR}/output.mid"
    transcriptor.transcribe(audio, midi_path)

    # splitext keeps this safe for names without an extension (the previous
    # split(".")[-2] raised IndexError) and keeps the whole stem for names
    # that contain several dots.
    stem = os.path.splitext(os.path.basename(audio_path))[0]
    return midi_path, stem.capitalize()


def inference(audio_path: str):
    """Run the upload-mode pipeline on a local audio file.

    The cache directory is reset, the audio is transcribed to MIDI, and the
    MIDI is converted to MusicXML, from which the ABC, MXL, PDF and JPG
    results are derived.

    Args:
        audio_path: Path of the uploaded audio file.

    Returns:
        ``(midi, pdf, xml, mxl, abc, jpg)`` in the order expected by the
        outputs of the "Upload mode" interface.
    """
    clean_cache()

    midi_file, score_title = audio2midi(audio_path)
    musicxml_file = midi2xml(midi_file, score_title)

    abc_notation = xml2abc(musicxml_file)
    mxl_file = xml2mxl(musicxml_file)
    pdf_file, jpg_file = xml2jpg(musicxml_file)

    return midi_file, pdf_file, musicxml_file, mxl_file, abc_notation, jpg_file


def get_first_integer(input_string):
    """Return the first run of digits in ``input_string`` as a decimal string.

    The digits are passed through ``int`` before being turned back into a
    string, so leading zeros are dropped.

    Args:
        input_string: Text to scan.

    Returns:
        The first integer found, as a string, or ``""`` when the text
        contains no digits.
    """
    first_digits = re.search(r"\d+", input_string)
    if first_digits is None:
        return ""

    return str(int(first_digits.group()))


def infer(audio_url: str):
    """Run the direct-link pipeline: download the audio, then transcribe it.

    Netease Cloud Music song-page links (host containing ``163``, not already
    ending in ``.mp3``) are rewritten to the service's outer media URL using
    the ``id`` query parameter found in the link.

    Args:
        audio_url: Direct audio link, or a Netease Cloud Music song page link.

    Returns:
        ``(audio, midi, pdf, xml, mxl, abc, jpg)`` in the order expected by
        the outputs of the "Direct link mode" interface.

    Raises:
        gr.Error: If the input is not a URL or the download fails, so the UI
            shows a readable message instead of an unrelated loader error.
    """
    clean_cache()

    audio_url = (audio_url or "").strip()
    if not is_url(audio_url):
        # Nothing can be downloaded, and the cache was just emptied, so there
        # is no file that could be transcribed either.
        raise gr.Error("Please enter a valid audio URL.")

    # Match on the host only, so digits elsewhere in the URL do not trigger
    # the Netease rewrite.
    if "163" in urlparse(audio_url).netloc and not audio_url.endswith(".mp3"):
        # Accept the id as the first or a later query parameter.
        id_param = re.search(r"[?&]id=", audio_url)
        song_id = get_first_integer(audio_url[id_param.end():]) if id_param else ""
        if song_id:
            audio_url = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
            print(audio_url)

    download_path = download_audio(audio_url, f"{CACHE_DIR}/output")
    if not download_path:
        # download_audio signals a non-200 response with an empty string.
        raise gr.Error("Failed to download audio from the given URL.")

    midi, title = audio2midi(download_path)
    xml = midi2xml(midi, title)
    mxl = xml2mxl(xml)
    pdf, jpg = xml2jpg(xml)
    abc = xml2abc(xml)

    return download_path, midi, pdf, xml, mxl, abc, jpg


# Build the web UI: one Blocks app with two tabs, each wrapping a gr.Interface.
# The order of every ``outputs`` list must match the tuple returned by the
# function given as ``fn``.
with gr.Blocks() as iface:
    # Tab 1: the user uploads an audio file; inference() receives its path.
    with gr.Tab("Upload mode"):
        gr.Interface(
            fn=inference,
            inputs=gr.Audio(
                label="Please upload 100% of the audio before clicking submit",
                type="filepath",
            ),
            # Matches inference(): midi, pdf, xml, mxl, abc, jpg.
            outputs=[
                gr.components.File(label="Download MIDI"),
                gr.components.File(label="Download PDF score"),
                gr.components.File(label="Download MusicXML"),
                gr.components.File(label="Download MXL"),
                gr.Textbox(label="abc notation", show_copy_button=True),
                gr.Image(label="Staff", type="filepath"),
            ],
            allow_flagging="never",
        )

    # Tab 2: the user pastes a link; infer() downloads the audio first.
    with gr.Tab("Direct link mode"):
        gr.Interface(
            fn=infer,
            inputs=gr.Textbox(
                label="Input audio direct link URL (Netease Cloud Music can directly input non-VIP song page link automatically resolved)"
            ),
            # Matches infer(): downloaded audio, midi, pdf, xml, mxl, abc, jpg.
            outputs=[
                gr.Audio(label="Download audio", type="filepath"),
                gr.components.File(label="Download MIDI"),
                gr.components.File(label="Download PDF score"),
                gr.components.File(label="Download MusicXML"),
                gr.components.File(label="Download MXL"),
                gr.Textbox(label="abc notation", show_copy_button=True),
                gr.Image(label="Staff", type="filepath"),
            ],
            allow_flagging="never",
        )

# Launched unconditionally at module level, i.e. also whenever this file is imported.
iface.launch()