Spaces:
Running
Running
MuGeminorum
commited on
Commit
•
653dc95
1
Parent(s):
37e9aba
upl base
Browse files- .gitattributes +12 -11
- .gitignore +8 -0
- README.md +4 -4
- app.py +215 -0
- conda.txt +5 -0
- config.py +19 -0
- render.py +73 -0
- requirements.txt +9 -0
- utils.py +388 -0
.gitattributes
CHANGED
@@ -1,35 +1,36 @@
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
4 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
|
11 |
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
13 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
17 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
20 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
22 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
|
24 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.db* filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.ark* filter=lfs diff=lfs merge=lfs -text
|
30 |
+
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
31 |
+
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
32 |
+
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.sf3 filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.AppImage filter=lfs diff=lfs merge=lfs -textlibnss3.so filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/*
|
2 |
+
output/*
|
3 |
+
rename.sh
|
4 |
+
test.py
|
5 |
+
gpt2-abcmusic/*
|
6 |
+
*.pth
|
7 |
+
tmp/*
|
8 |
+
mscore3/*
|
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
title: HoyoGPT
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
1 |
---
|
2 |
title: HoyoGPT
|
3 |
+
emoji: 🎹
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.7.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import torch
|
5 |
+
import shutil
|
6 |
+
import argparse
|
7 |
+
import gradio as gr
|
8 |
+
from utils import *
|
9 |
+
from config import *
|
10 |
+
from render import *
|
11 |
+
from music21 import converter
|
12 |
+
from transformers import GPT2Config
|
13 |
+
import warnings
|
14 |
+
warnings.filterwarnings('ignore')
|
15 |
+
|
16 |
+
|
17 |
+
def abc_to_midi(abc_content, output_midi_path):
|
18 |
+
# 解析 ABC 格式的乐谱
|
19 |
+
score = converter.parse(abc_content)
|
20 |
+
|
21 |
+
# 将乐谱保存为 MIDI 文件
|
22 |
+
score.write('midi', fp=output_midi_path)
|
23 |
+
return output_midi_path
|
24 |
+
|
25 |
+
|
26 |
+
def get_args(parser):
|
27 |
+
parser.add_argument('-num_tunes', type=int, default=1,
|
28 |
+
help='the number of independently computed returned tunes')
|
29 |
+
parser.add_argument('-max_patch', type=int, default=128,
|
30 |
+
help='integer to define the maximum length in tokens of each tune')
|
31 |
+
parser.add_argument('-top_p', type=float, default=0.8,
|
32 |
+
help='float to define the tokens that are within the sample operation of text generation')
|
33 |
+
parser.add_argument('-top_k', type=int, default=8,
|
34 |
+
help='integer to define the tokens that are within the sample operation of text generation')
|
35 |
+
parser.add_argument('-temperature', type=float, default=1.2,
|
36 |
+
help='the temperature of the sampling operation')
|
37 |
+
parser.add_argument('-seed', type=int, default=None,
|
38 |
+
help='seed for randomstate')
|
39 |
+
parser.add_argument('-show_control_code', type=bool,
|
40 |
+
default=True, help='whether to show control code')
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
return args
|
44 |
+
|
45 |
+
|
46 |
+
def generate_abc(args, region):
|
47 |
+
patchilizer = Patchilizer()
|
48 |
+
|
49 |
+
patch_config = GPT2Config(
|
50 |
+
num_hidden_layers=PATCH_NUM_LAYERS,
|
51 |
+
max_length=PATCH_LENGTH,
|
52 |
+
max_position_embeddings=PATCH_LENGTH,
|
53 |
+
vocab_size=1
|
54 |
+
)
|
55 |
+
|
56 |
+
char_config = GPT2Config(
|
57 |
+
num_hidden_layers=CHAR_NUM_LAYERS,
|
58 |
+
max_length=PATCH_SIZE,
|
59 |
+
max_position_embeddings=PATCH_SIZE,
|
60 |
+
vocab_size=128
|
61 |
+
)
|
62 |
+
|
63 |
+
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
|
64 |
+
|
65 |
+
filename = WEIGHT_PATH
|
66 |
+
|
67 |
+
if os.path.exists(filename):
|
68 |
+
print(f"Weights already exist at '{filename}'. Loading...")
|
69 |
+
|
70 |
+
else:
|
71 |
+
download()
|
72 |
+
|
73 |
+
checkpoint = torch.load(filename, map_location=torch.device('cpu'))
|
74 |
+
model.load_state_dict(checkpoint['model'])
|
75 |
+
model = model.to(device)
|
76 |
+
model.eval()
|
77 |
+
|
78 |
+
prompt = template(region)
|
79 |
+
|
80 |
+
tunes = ""
|
81 |
+
num_tunes = args.num_tunes
|
82 |
+
max_patch = args.max_patch
|
83 |
+
top_p = args.top_p
|
84 |
+
top_k = args.top_k
|
85 |
+
temperature = args.temperature
|
86 |
+
seed = args.seed
|
87 |
+
show_control_code = args.show_control_code
|
88 |
+
|
89 |
+
print(" HYPERPARAMETERS ".center(60, "#"), '\n')
|
90 |
+
args = vars(args)
|
91 |
+
|
92 |
+
for key in args.keys():
|
93 |
+
print(f'{key}: {str(args[key])}')
|
94 |
+
|
95 |
+
print('\n', " OUTPUT TUNES ".center(60, "#"))
|
96 |
+
|
97 |
+
start_time = time.time()
|
98 |
+
|
99 |
+
for i in range(num_tunes):
|
100 |
+
tune = f"X:{str(i + 1)}\n{prompt}"
|
101 |
+
lines = re.split(r'(\n)', tune)
|
102 |
+
tune = ""
|
103 |
+
skip = False
|
104 |
+
for line in lines:
|
105 |
+
if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
|
106 |
+
if not skip:
|
107 |
+
print(line, end="")
|
108 |
+
tune += line
|
109 |
+
|
110 |
+
skip = False
|
111 |
+
|
112 |
+
else:
|
113 |
+
skip = True
|
114 |
+
|
115 |
+
input_patches = torch.tensor(
|
116 |
+
[patchilizer.encode(prompt, add_special_patches=True)[:-1]],
|
117 |
+
device=device
|
118 |
+
)
|
119 |
+
|
120 |
+
if tune == "":
|
121 |
+
tokens = None
|
122 |
+
|
123 |
+
else:
|
124 |
+
prefix = patchilizer.decode(input_patches[0])
|
125 |
+
remaining_tokens = prompt[len(prefix):]
|
126 |
+
tokens = torch.tensor(
|
127 |
+
[patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens],
|
128 |
+
device=device
|
129 |
+
)
|
130 |
+
|
131 |
+
while input_patches.shape[1] < max_patch:
|
132 |
+
predicted_patch, seed = model.generate(
|
133 |
+
input_patches,
|
134 |
+
tokens,
|
135 |
+
top_p=top_p,
|
136 |
+
top_k=top_k,
|
137 |
+
temperature=temperature,
|
138 |
+
seed=seed
|
139 |
+
)
|
140 |
+
tokens = None
|
141 |
+
|
142 |
+
if predicted_patch[0] != patchilizer.eos_token_id:
|
143 |
+
next_bar = patchilizer.decode([predicted_patch])
|
144 |
+
|
145 |
+
if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
|
146 |
+
print(next_bar, end="")
|
147 |
+
tune += next_bar
|
148 |
+
|
149 |
+
if next_bar == "":
|
150 |
+
break
|
151 |
+
|
152 |
+
next_bar = remaining_tokens+next_bar
|
153 |
+
remaining_tokens = ""
|
154 |
+
|
155 |
+
predicted_patch = torch.tensor(
|
156 |
+
patchilizer.bar2patch(next_bar),
|
157 |
+
device=device
|
158 |
+
).unsqueeze(0)
|
159 |
+
|
160 |
+
input_patches = torch.cat(
|
161 |
+
[input_patches, predicted_patch.unsqueeze(0)],
|
162 |
+
dim=1
|
163 |
+
)
|
164 |
+
|
165 |
+
else:
|
166 |
+
break
|
167 |
+
|
168 |
+
tunes += f"{tune}\n\n"
|
169 |
+
print("\n")
|
170 |
+
|
171 |
+
print("Generation time: {:.2f} seconds".format(time.time() - start_time))
|
172 |
+
create_dir('./tmp')
|
173 |
+
timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
|
174 |
+
out_midi = abc_to_midi(tunes, f'./tmp/[{region}]{timestamp}.mid')
|
175 |
+
add_path()
|
176 |
+
png_file = midi2png(out_midi)
|
177 |
+
wav_file = midi2wav(out_midi)
|
178 |
+
|
179 |
+
return tunes, out_midi, png_file, wav_file
|
180 |
+
|
181 |
+
|
182 |
+
def inference(region):
|
183 |
+
if os.path.exists('./tmp'):
|
184 |
+
shutil.rmtree('./tmp')
|
185 |
+
|
186 |
+
parser = argparse.ArgumentParser()
|
187 |
+
args = get_args(parser)
|
188 |
+
return generate_abc(args, region)
|
189 |
+
|
190 |
+
|
191 |
+
with gr.Blocks() as demo:
|
192 |
+
with gr.Row():
|
193 |
+
with gr.Column():
|
194 |
+
region_opt = gr.Dropdown(
|
195 |
+
choices=[
|
196 |
+
'Mondstadt', 'Liyue', 'Inazuma', 'Sumeru', 'Fontaine'
|
197 |
+
],
|
198 |
+
value='Liyue',
|
199 |
+
label='Region'
|
200 |
+
)
|
201 |
+
gen_btn = gr.Button("Generate")
|
202 |
+
|
203 |
+
with gr.Column():
|
204 |
+
wav_output = gr.Audio(label='Audio', type='filepath')
|
205 |
+
dld_midi = gr.components.File(label="Download MIDI")
|
206 |
+
abc_output = gr.TextArea(label='abc score')
|
207 |
+
img_score = gr.Image(label='Staff', type='filepath')
|
208 |
+
|
209 |
+
gen_btn.click(
|
210 |
+
inference,
|
211 |
+
inputs=region_opt,
|
212 |
+
outputs=[abc_output, dld_midi, img_score, wav_output]
|
213 |
+
)
|
214 |
+
|
215 |
+
demo.launch(share=True)
|
conda.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python=3.10
|
2 |
+
pytorch=1.12.1
|
3 |
+
torchvision=0.13.1
|
4 |
+
torchaudio=0.12.1
|
5 |
+
cudatoolkit=11.3.1
|
config.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PATCH_LENGTH = 128 # Patch Length
|
2 |
+
PATCH_SIZE = 32 # Patch Size
|
3 |
+
|
4 |
+
PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
|
5 |
+
CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
|
6 |
+
|
7 |
+
# Number of epochs to train for (if early stopping doesn't intervene)
|
8 |
+
NUM_EPOCHS = 5 # 32
|
9 |
+
LEARNING_RATE = 5e-5 # Learning rate for the optimizer
|
10 |
+
# Batch size for patch during training, 0 for full context
|
11 |
+
PATCH_SAMPLING_BATCH_SIZE = 0
|
12 |
+
LOAD_FROM_CHECKPOINT = True # Whether to load weights from a checkpoint
|
13 |
+
# Whether to share weights between the encoder and decoder
|
14 |
+
SHARE_WEIGHTS = False
|
15 |
+
WEIGHT_URL = 'https://huggingface.co/MuGeminorum/hoyoGPT/resolve/main/weights.pth'
|
16 |
+
ZH_WEIGHT_URL = 'https://www.modelscope.cn/api/v1/models/MuGeminorum/hoyoGPT/repo?Revision=master&FilePath=weights.pth'
|
17 |
+
WEIGHT_PATH = 'weights.pth'
|
18 |
+
LOG_PATH = 'logs.txt'
|
19 |
+
PROMPT_PATH = 'prompt.txt'
|
render.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import subprocess
|
4 |
+
from PIL import Image
|
5 |
+
from pdf2image import convert_from_path
|
6 |
+
from utils import download
|
7 |
+
|
8 |
+
|
9 |
+
def add_path():
|
10 |
+
# """
|
11 |
+
# 将指定目录添加到 LD_LIBRARY_PATH 环境变量中,并在当前 Python 进程中生效。
|
12 |
+
|
13 |
+
# Parameters:
|
14 |
+
# - directory_path (str): 要添加的目录路径。
|
15 |
+
# """
|
16 |
+
# dir_path = os.path.join(os.getcwd(), 'lib')
|
17 |
+
# # 获取当前环境变量的值
|
18 |
+
# current_path = os.environ.get("LD_LIBRARY_PATH", "")
|
19 |
+
|
20 |
+
# # 将目录路径添加到 LD_LIBRARY_PATH 中
|
21 |
+
# new_path = f"{current_path}:{dir_path}"
|
22 |
+
|
23 |
+
# # 设置 LD_LIBRARY_PATH 环境变量,以便在当前 Python 进程中生效
|
24 |
+
# os.environ["LD_LIBRARY_PATH"] = new_path
|
25 |
+
os.environ['QT_QPA_PLATFORM'] = 'offscreen'
|
26 |
+
|
27 |
+
|
28 |
+
if sys.platform.startswith('linux'):
|
29 |
+
apkname = 'MuseScore.AppImage'
|
30 |
+
extra_dir = 'squashfs-root'
|
31 |
+
download(
|
32 |
+
filename=apkname,
|
33 |
+
url='https://cdn.jsdelivr.net/musescore/v4.2.0/MuseScore-4.2.0.233521125-x86_64.AppImage'
|
34 |
+
)
|
35 |
+
if not os.path.exists(extra_dir):
|
36 |
+
subprocess.run(['chmod', '+x', f'./{apkname}'])
|
37 |
+
subprocess.run([f'./{apkname}', '--appimage-extract'])
|
38 |
+
|
39 |
+
mscore = f'./{extra_dir}/AppRun'
|
40 |
+
|
41 |
+
else:
|
42 |
+
mscore = "D:/Program Files/MuseScore 3/bin/MuseScore3.exe"
|
43 |
+
|
44 |
+
|
45 |
+
def midi2wav(mid_file: str):
|
46 |
+
wav_file = mid_file.replace('.mid', '.wav')
|
47 |
+
command = [mscore, "-o", wav_file, mid_file]
|
48 |
+
result = subprocess.run(command)
|
49 |
+
print(result)
|
50 |
+
return wav_file
|
51 |
+
|
52 |
+
|
53 |
+
def pdf_to_img(pdf_path: str):
|
54 |
+
output_path = pdf_path.replace('.pdf', '.jpg')
|
55 |
+
images = convert_from_path(pdf_path)
|
56 |
+
combined_image = Image.new(
|
57 |
+
'RGB', (images[0].width, sum(image.height for image in images))
|
58 |
+
)
|
59 |
+
y_offset = 0
|
60 |
+
for image in images:
|
61 |
+
combined_image.paste(image, (0, y_offset))
|
62 |
+
y_offset += image.height
|
63 |
+
|
64 |
+
combined_image.save(output_path)
|
65 |
+
return output_path
|
66 |
+
|
67 |
+
|
68 |
+
def midi2png(mid_file: str):
|
69 |
+
pdf_score = mid_file.replace('.mid', '.pdf')
|
70 |
+
command = [mscore, "-o", pdf_score, mid_file]
|
71 |
+
result = subprocess.run(command)
|
72 |
+
print(result)
|
73 |
+
return pdf_to_img(pdf_score)
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.18.0
|
2 |
+
samplings==0.1.7
|
3 |
+
unidecode
|
4 |
+
music21
|
5 |
+
autopep8
|
6 |
+
pillow==9.4.0
|
7 |
+
gradio
|
8 |
+
pdf2image
|
9 |
+
torch
|
utils.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
from config import *
|
6 |
+
from tqdm import tqdm
|
7 |
+
from unidecode import unidecode
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
|
10 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
11 |
+
|
12 |
+
if torch.cuda.is_available():
|
13 |
+
device = torch.device("cuda")
|
14 |
+
else:
|
15 |
+
device = torch.device("cpu")
|
16 |
+
|
17 |
+
|
18 |
+
def template(region):
|
19 |
+
return f'''A:{region}
|
20 |
+
S:2
|
21 |
+
B:9
|
22 |
+
E:4
|
23 |
+
B:9
|
24 |
+
L:1/8
|
25 |
+
M:3/4
|
26 |
+
K:D
|
27 |
+
de |"D"'''
|
28 |
+
|
29 |
+
|
30 |
+
def create_dir(dir_path):
|
31 |
+
if not os.path.exists(dir_path):
|
32 |
+
os.makedirs(dir_path)
|
33 |
+
|
34 |
+
|
35 |
+
def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
|
36 |
+
import time
|
37 |
+
import requests
|
38 |
+
try:
|
39 |
+
response = requests.get(url, stream=True)
|
40 |
+
total_size = int(response.headers.get('content-length', 0))
|
41 |
+
chunk_size = 1024
|
42 |
+
|
43 |
+
with open(filename, 'wb') as file, tqdm(
|
44 |
+
desc=f"Downloading weights to '{filename}'...",
|
45 |
+
total=total_size,
|
46 |
+
unit='B',
|
47 |
+
unit_scale=True,
|
48 |
+
unit_divisor=1024,
|
49 |
+
) as bar:
|
50 |
+
for data in response.iter_content(chunk_size=chunk_size):
|
51 |
+
size = file.write(data)
|
52 |
+
bar.update(size)
|
53 |
+
|
54 |
+
except ConnectionError as e:
|
55 |
+
print(f"Error: {e}")
|
56 |
+
time.sleep(3)
|
57 |
+
download(filename, ZH_WEIGHT_URL)
|
58 |
+
|
59 |
+
|
60 |
+
class Patchilizer:
|
61 |
+
"""
|
62 |
+
A class for converting music bars to patches and vice versa.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self):
|
66 |
+
self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
|
67 |
+
self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
|
68 |
+
self.pad_token_id = 0
|
69 |
+
self.bos_token_id = 1
|
70 |
+
self.eos_token_id = 2
|
71 |
+
|
72 |
+
def split_bars(self, body):
|
73 |
+
"""
|
74 |
+
Split a body of music into individual bars.
|
75 |
+
"""
|
76 |
+
bars = re.split(self.regexPattern, ''.join(body))
|
77 |
+
bars = list(filter(None, bars))
|
78 |
+
# remove empty strings
|
79 |
+
if bars[0] in self.delimiters:
|
80 |
+
bars[1] = bars[0] + bars[1]
|
81 |
+
bars = bars[1:]
|
82 |
+
|
83 |
+
bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
|
84 |
+
return bars
|
85 |
+
|
86 |
+
def bar2patch(self, bar, patch_size=PATCH_SIZE):
|
87 |
+
"""
|
88 |
+
Convert a bar into a patch of specified length.
|
89 |
+
"""
|
90 |
+
patch = [self.bos_token_id] + \
|
91 |
+
[ord(c) for c in bar] + [self.eos_token_id]
|
92 |
+
patch = patch[:patch_size]
|
93 |
+
patch += [self.pad_token_id] * (patch_size - len(patch))
|
94 |
+
return patch
|
95 |
+
|
96 |
+
def patch2bar(self, patch):
|
97 |
+
"""
|
98 |
+
Convert a patch into a bar.
|
99 |
+
"""
|
100 |
+
return ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in patch if idx != self.eos_token_id)
|
101 |
+
|
102 |
+
def encode(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=False):
|
103 |
+
"""
|
104 |
+
Encode music into patches of specified length.
|
105 |
+
"""
|
106 |
+
lines = unidecode(abc_code).split('\n')
|
107 |
+
lines = list(filter(None, lines)) # remove empty lines
|
108 |
+
|
109 |
+
body = ""
|
110 |
+
patches = []
|
111 |
+
|
112 |
+
for line in lines:
|
113 |
+
if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
|
114 |
+
if body:
|
115 |
+
bars = self.split_bars(body)
|
116 |
+
patches.extend(
|
117 |
+
self.bar2patch(bar + '\n' if idx == len(bars) - 1 else bar, patch_size) for idx, bar in enumerate(bars)
|
118 |
+
)
|
119 |
+
body = ""
|
120 |
+
|
121 |
+
patches.append(self.bar2patch(line + '\n', patch_size))
|
122 |
+
|
123 |
+
else:
|
124 |
+
body += line + '\n'
|
125 |
+
|
126 |
+
if body:
|
127 |
+
patches.extend(
|
128 |
+
self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
|
129 |
+
)
|
130 |
+
|
131 |
+
if add_special_patches:
|
132 |
+
bos_patch = [self.bos_token_id] * \
|
133 |
+
(patch_size-1) + [self.eos_token_id]
|
134 |
+
eos_patch = [self.bos_token_id] + \
|
135 |
+
[self.eos_token_id] * (patch_size-1)
|
136 |
+
patches = [bos_patch] + patches + [eos_patch]
|
137 |
+
|
138 |
+
return patches[:patch_length]
|
139 |
+
|
140 |
+
def decode(self, patches):
|
141 |
+
"""
|
142 |
+
Decode patches into music.
|
143 |
+
"""
|
144 |
+
return ''.join(self.patch2bar(patch) for patch in patches)
|
145 |
+
|
146 |
+
|
147 |
+
class PatchLevelDecoder(PreTrainedModel):
|
148 |
+
"""
|
149 |
+
An Patch-level Decoder model for generating patch features in an auto-regressive manner.
|
150 |
+
It inherits PreTrainedModel from transformers.
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(self, config):
|
154 |
+
super().__init__(config)
|
155 |
+
self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
|
156 |
+
torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
|
157 |
+
self.base = GPT2Model(config)
|
158 |
+
|
159 |
+
def forward(self, patches: torch.Tensor) -> torch.Tensor:
|
160 |
+
"""
|
161 |
+
The forward pass of the patch-level decoder model.
|
162 |
+
:param patches: the patches to be encoded
|
163 |
+
:return: the encoded patches
|
164 |
+
"""
|
165 |
+
patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
|
166 |
+
patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
|
167 |
+
patches = self.patch_embedding(patches.to(self.device))
|
168 |
+
|
169 |
+
return self.base(inputs_embeds=patches)
|
170 |
+
|
171 |
+
|
172 |
+
class CharLevelDecoder(PreTrainedModel):
|
173 |
+
"""
|
174 |
+
A Char-level Decoder model for generating the characters within each bar patch sequentially.
|
175 |
+
It inherits PreTrainedModel from transformers.
|
176 |
+
"""
|
177 |
+
|
178 |
+
def __init__(self, config):
|
179 |
+
super().__init__(config)
|
180 |
+
self.pad_token_id = 0
|
181 |
+
self.bos_token_id = 1
|
182 |
+
self.eos_token_id = 2
|
183 |
+
self.base = GPT2LMHeadModel(config)
|
184 |
+
|
185 |
+
def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor, patch_sampling_batch_size: int):
|
186 |
+
"""
|
187 |
+
The forward pass of the char-level decoder model.
|
188 |
+
:param encoded_patches: the encoded patches
|
189 |
+
:param target_patches: the target patches
|
190 |
+
:return: the decoded patches
|
191 |
+
"""
|
192 |
+
# preparing the labels for model training
|
193 |
+
target_masks = target_patches == self.pad_token_id
|
194 |
+
labels = target_patches.clone().masked_fill_(target_masks, -100)
|
195 |
+
|
196 |
+
# masking the labels for model training
|
197 |
+
target_masks = torch.ones_like(labels)
|
198 |
+
target_masks = target_masks.masked_fill_(labels == -100, 0)
|
199 |
+
|
200 |
+
# select patches
|
201 |
+
if patch_sampling_batch_size != 0 and patch_sampling_batch_size < target_patches.shape[0]:
|
202 |
+
indices = list(range(len(target_patches)))
|
203 |
+
random.shuffle(indices)
|
204 |
+
selected_indices = sorted(indices[:patch_sampling_batch_size])
|
205 |
+
|
206 |
+
target_patches = target_patches[selected_indices, :]
|
207 |
+
target_masks = target_masks[selected_indices, :]
|
208 |
+
encoded_patches = encoded_patches[selected_indices, :]
|
209 |
+
labels = labels[selected_indices, :]
|
210 |
+
|
211 |
+
# get input embeddings
|
212 |
+
inputs_embeds = torch.nn.functional.embedding(
|
213 |
+
target_patches,
|
214 |
+
self.base.transformer.wte.weight
|
215 |
+
)
|
216 |
+
|
217 |
+
# concatenate the encoded patches with the input embeddings
|
218 |
+
inputs_embeds = torch.cat(
|
219 |
+
(encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]),
|
220 |
+
dim=1
|
221 |
+
)
|
222 |
+
|
223 |
+
return self.base(
|
224 |
+
inputs_embeds=inputs_embeds,
|
225 |
+
attention_mask=target_masks,
|
226 |
+
labels=labels
|
227 |
+
)
|
228 |
+
|
229 |
+
def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
|
230 |
+
"""
|
231 |
+
The generate function for generating a patch based on the encoded patch and already generated tokens.
|
232 |
+
:param encoded_patch: the encoded patch
|
233 |
+
:param tokens: already generated tokens in the patch
|
234 |
+
:return: the probability distribution of next token
|
235 |
+
"""
|
236 |
+
encoded_patch = encoded_patch.reshape(1, 1, -1)
|
237 |
+
tokens = tokens.reshape(1, -1)
|
238 |
+
|
239 |
+
# Get input embeddings
|
240 |
+
tokens = torch.nn.functional.embedding(
|
241 |
+
tokens,
|
242 |
+
self.base.transformer.wte.weight
|
243 |
+
)
|
244 |
+
|
245 |
+
# Concatenate the encoded patch with the input embeddings
|
246 |
+
tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
|
247 |
+
|
248 |
+
# Get output from model
|
249 |
+
outputs = self.base(inputs_embeds=tokens)
|
250 |
+
|
251 |
+
# Get probabilities of next token
|
252 |
+
probs = torch.nn.functional.softmax(
|
253 |
+
outputs.logits.squeeze(0)[-1],
|
254 |
+
dim=-1
|
255 |
+
)
|
256 |
+
|
257 |
+
return probs
|
258 |
+
|
259 |
+
|
260 |
+
class TunesFormer(PreTrainedModel):
|
261 |
+
"""
|
262 |
+
TunesFormer is a hierarchical music generation model based on bar patching.
|
263 |
+
It includes a patch-level decoder and a character-level decoder.
|
264 |
+
It inherits PreTrainedModel from transformers.
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(self, encoder_config, decoder_config, share_weights=False):
|
268 |
+
super().__init__(encoder_config)
|
269 |
+
self.pad_token_id = 0
|
270 |
+
self.bos_token_id = 1
|
271 |
+
self.eos_token_id = 2
|
272 |
+
if share_weights:
|
273 |
+
max_layers = max(
|
274 |
+
encoder_config.num_hidden_layers,
|
275 |
+
decoder_config.num_hidden_layers
|
276 |
+
)
|
277 |
+
|
278 |
+
max_context_size = max(
|
279 |
+
encoder_config.max_length,
|
280 |
+
decoder_config.max_length
|
281 |
+
)
|
282 |
+
|
283 |
+
max_position_embeddings = max(
|
284 |
+
encoder_config.max_position_embeddings,
|
285 |
+
decoder_config.max_position_embeddings
|
286 |
+
)
|
287 |
+
|
288 |
+
encoder_config.num_hidden_layers = max_layers
|
289 |
+
encoder_config.max_length = max_context_size
|
290 |
+
encoder_config.max_position_embeddings = max_position_embeddings
|
291 |
+
decoder_config.num_hidden_layers = max_layers
|
292 |
+
decoder_config.max_length = max_context_size
|
293 |
+
decoder_config.max_position_embeddings = max_position_embeddings
|
294 |
+
|
295 |
+
self.patch_level_decoder = PatchLevelDecoder(encoder_config)
|
296 |
+
self.char_level_decoder = CharLevelDecoder(decoder_config)
|
297 |
+
|
298 |
+
if share_weights:
|
299 |
+
self.patch_level_decoder.base = self.char_level_decoder.base.transformer
|
300 |
+
|
301 |
+
def forward(self, patches: torch.Tensor, patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE):
|
302 |
+
"""
|
303 |
+
The forward pass of the TunesFormer model.
|
304 |
+
:param patches: the patches to be both encoded and decoded
|
305 |
+
:return: the decoded patches
|
306 |
+
"""
|
307 |
+
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
|
308 |
+
encoded_patches = self.patch_level_decoder(
|
309 |
+
patches)["last_hidden_state"]
|
310 |
+
|
311 |
+
return self.char_level_decoder(encoded_patches.squeeze(0)[:-1, :], patches.squeeze(0)[1:, :], patch_sampling_batch_size)
|
312 |
+
|
313 |
+
def generate(
|
314 |
+
self,
|
315 |
+
patches: torch.Tensor,
|
316 |
+
tokens: torch.Tensor,
|
317 |
+
top_p: float = 1,
|
318 |
+
top_k: int = 0,
|
319 |
+
temperature: float = 1,
|
320 |
+
seed: int = None
|
321 |
+
):
|
322 |
+
"""
|
323 |
+
The generate function for generating patches based on patches.
|
324 |
+
:param patches: the patches to be encoded
|
325 |
+
:return: the generated patches
|
326 |
+
"""
|
327 |
+
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
|
328 |
+
encoded_patches = self.patch_level_decoder(
|
329 |
+
patches)["last_hidden_state"]
|
330 |
+
|
331 |
+
if tokens == None:
|
332 |
+
tokens = torch.tensor([self.bos_token_id], device=self.device)
|
333 |
+
|
334 |
+
generated_patch = []
|
335 |
+
random.seed(seed)
|
336 |
+
|
337 |
+
while True:
|
338 |
+
if seed != None:
|
339 |
+
n_seed = random.randint(0, 1000000)
|
340 |
+
random.seed(n_seed)
|
341 |
+
|
342 |
+
else:
|
343 |
+
n_seed = None
|
344 |
+
|
345 |
+
prob = self.char_level_decoder.generate(
|
346 |
+
encoded_patches[0][-1],
|
347 |
+
tokens
|
348 |
+
).cpu().detach().numpy()
|
349 |
+
|
350 |
+
prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
|
351 |
+
prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
|
352 |
+
|
353 |
+
token = temperature_sampling(
|
354 |
+
prob,
|
355 |
+
temperature=temperature,
|
356 |
+
seed=n_seed
|
357 |
+
)
|
358 |
+
|
359 |
+
generated_patch.append(token)
|
360 |
+
if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
|
361 |
+
break
|
362 |
+
|
363 |
+
else:
|
364 |
+
tokens = torch.cat(
|
365 |
+
(tokens, torch.tensor([token], device=self.device)),
|
366 |
+
dim=0
|
367 |
+
)
|
368 |
+
|
369 |
+
return generated_patch, n_seed
|
370 |
+
|
371 |
+
|
372 |
+
class PatchilizedData(Dataset):
|
373 |
+
def __init__(self, items, patchilizer):
|
374 |
+
self.texts = []
|
375 |
+
|
376 |
+
for item in tqdm(items):
|
377 |
+
text = item['control code'] + \
|
378 |
+
"\n".join(item['abc notation'].split('\n')[1:])
|
379 |
+
input_patch = patchilizer.encode(text, add_special_patches=True)
|
380 |
+
input_patch = torch.tensor(input_patch)
|
381 |
+
if torch.sum(input_patch) != 0:
|
382 |
+
self.texts.append(input_patch)
|
383 |
+
|
384 |
+
def __len__(self):
|
385 |
+
return len(self.texts)
|
386 |
+
|
387 |
+
def __getitem__(self, idx):
|
388 |
+
return self.texts[idx]
|