Spaces:
Running
Running
Monan Zhou
committed on
Commit
•
6c48757
1
Parent(s):
dc7b3e1
Upload 5 files
Browse files- app.py +267 -237
- config.py +9 -19
- convert.py +88 -93
- requirements.txt +8 -8
- utils.py +377 -394
app.py
CHANGED
@@ -1,237 +1,267 @@
|
|
1 |
-
import re
|
2 |
-
import os
|
3 |
-
import time
|
4 |
-
import torch
|
5 |
-
import shutil
|
6 |
-
import argparse
|
7 |
-
import gradio as gr
|
8 |
-
from
|
9 |
-
from
|
10 |
-
from convert import
|
11 |
-
from
|
12 |
-
import
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
"
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
tokens =
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
def inference(region):
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import torch
|
5 |
+
import shutil
|
6 |
+
import argparse
|
7 |
+
import gradio as gr
|
8 |
+
from config import *
|
9 |
+
from utils import Patchilizer, TunesFormer, DEVICE
|
10 |
+
from convert import abc_to_midi, abc_to_musicxml, musicxml_to_mxl, mxl2jpg, midi2wav
|
11 |
+
from modelscope import snapshot_download
|
12 |
+
from transformers import GPT2Config
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
warnings.filterwarnings("ignore")
|
16 |
+
|
17 |
+
# 模型下载
|
18 |
+
MODEL_DIR = snapshot_download("MuGeminorum/hoyoGPT")
|
19 |
+
|
20 |
+
|
21 |
+
def _parse_bool_flag(value: str) -> bool:
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because any non-empty string is truthy. This converter recognises the
    usual spellings explicitly; the empty string stays ``False`` as before.

    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean
    """
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(parser: argparse.ArgumentParser):
    """Register the generation hyperparameters on *parser* and parse ``sys.argv``.

    :param parser: the parser to which the options are added
    :return: the ``argparse.Namespace`` with ``num_tunes``, ``max_patch``,
        ``top_p``, ``top_k``, ``temperature``, ``seed`` and ``show_control_code``
    """
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="integer to define the maximum length in tokens of each tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="float to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="integer to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
    parser.add_argument(
        "-show_control_code",
        # NOT type=bool: that would turn the string "False" into True.
        type=_parse_bool_flag,
        default=False,
        help="whether to show control code",
    )
    return parser.parse_args()
|
62 |
+
|
63 |
+
|
64 |
+
def _build_tunesformer(epochs: str):
    """Build a TunesFormer and load the checkpoint at ``MODEL_DIR/<epochs>/weights.pth``."""
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    filename = f"{MODEL_DIR}/{epochs}/weights.pth"
    # Load on CPU first so the checkpoint opens on machines without CUDA.
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()
    return model


def _sample_tunes(model, patchilizer, args, region: str, seed):
    """Sample ``args.num_tunes`` ABC tunes conditioned on *region*.

    Text is echoed to stdout as it is produced.

    :return: ``(tunes, seed)`` - the concatenated ABC text and the seed value
        last returned by ``model.generate``
    """
    control_prefixes = ["S:", "B:", "E:"]
    show_control_code = args.show_control_code
    prompt = f"A:{region}\n"
    tunes = ""

    for i in range(args.num_tunes):
        title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
        header = f"X:{str(i + 1)}\n{title_artist + prompt}"

        # Copy the header into the tune, dropping control-code lines (and the
        # newline that follows each of them) unless they were requested.
        tune = ""
        skip = False
        for line in re.split(r"(\n)", header):
            if show_control_code or line[:2] not in control_prefixes:
                if not skip:
                    print(line, end="")
                    tune += line
                skip = False
            else:
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
        )

        # Defined up front so the generation loop can never hit an unbound name.
        remaining_tokens = ""
        if tune == "":
            tokens = None
        else:
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix) :]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )

        while input_patches.shape[1] < args.max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=args.top_p,
                top_k=args.top_k,
                temperature=args.temperature,
                seed=seed,
            )
            tokens = None

            if predicted_patch[0] == patchilizer.eos_token_id:
                break

            next_bar = patchilizer.decode([predicted_patch])
            if show_control_code or next_bar[:2] not in control_prefixes:
                print(next_bar, end="")
                tune += next_bar

            if next_bar == "":
                break

            # Characters of the prompt not yet covered by a patch belong to the
            # first generated bar.
            next_bar = remaining_tokens + next_bar
            remaining_tokens = ""

            predicted_patch = torch.tensor(
                patchilizer.bar2patch(next_bar), device=DEVICE
            ).unsqueeze(0)
            input_patches = torch.cat(
                [input_patches, predicted_patch.unsqueeze(0)], dim=1
            )

        tunes += f"{tune}\n\n"
        print("\n")

    return tunes, seed


def _export_tunes(tunes: str, region: str):
    """Write *tunes* to ``TEMP_DIR`` as MIDI, MusicXML, MXL, PDF, JPG and WAV files."""
    os.makedirs(TEMP_DIR, exist_ok=True)
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    out_midi = abc_to_midi(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.mid")
    out_xml = abc_to_musicxml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
    out_mxl = musicxml_to_mxl(f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
    pdf_file, jpg_file = mxl2jpg(out_mxl)
    wav_file = midi2wav(out_midi)
    return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file


def generate_abc(args, epochs: str, region: str, max_attempts: int = 10):
    """Generate ABC tunes for *region* and convert them to downloadable files.

    The model is loaded once. If conversion fails (typically because the
    sampled ABC text is not valid), new tunes are sampled and conversion is
    retried, at most *max_attempts* times. Earlier versions retried by calling
    themselves recursively, which reloaded the checkpoint on every retry and
    could end in ``RecursionError``.

    :param args: namespace produced by ``get_args``
    :param epochs: name of the checkpoint sub-directory inside ``MODEL_DIR``
    :param region: region name used as the ``A:`` field of the prompt
    :param max_attempts: upper bound on sample-and-convert attempts
    :return: ``(tunes, midi, pdf, musicxml, mxl, jpg, wav)`` - the ABC text
        followed by the paths of the exported files
    :raises RuntimeError: if every attempt failed; chained to the last error
    """
    patchilizer = Patchilizer()
    model = _build_tunesformer(epochs)

    print(" HYPERPARAMETERS ".center(60, "#"), "\n")
    for key, value in vars(args).items():
        print(f"{key}: {str(value)}")

    # The seed returned by the model is carried into the next attempt so a
    # fixed user seed does not replay the same failing tune.
    # NOTE(review): assumes model.generate returns an advanced seed - confirm.
    seed = args.seed
    last_error = None

    for _ in range(max(1, max_attempts)):
        print("\n", " OUTPUT TUNES ".center(60, "#"))
        start_time = time.time()
        tunes, seed = _sample_tunes(model, patchilizer, args, region, seed)
        print("Generation time: {:.2f} seconds".format(time.time() - start_time))

        try:
            return _export_tunes(tunes, region)
        except Exception as e:  # converters fail in many ways on malformed ABC
            print(f"Invalid abc generated: {e}, retrying...")
            last_error = e

    raise RuntimeError(
        f"no valid abc could be generated after {max(1, max_attempts)} attempts"
    ) from last_error
|
193 |
+
|
194 |
+
|
195 |
+
def inference(epochs: str, region: str):
    """Gradio callback: generate music for a region chosen by its Chinese name.

    Clears any previous ``TEMP_DIR``, parses the generation options with
    ``get_args`` and delegates to ``generate_abc`` with the English region name.

    :param epochs: checkpoint selector forwarded to ``generate_abc``
    :param region: one of the Chinese region names offered in the UI
    :return: whatever ``generate_abc`` returns
    """
    # Start every request from an empty output directory.
    if os.path.exists(TEMP_DIR):
        shutil.rmtree(TEMP_DIR)

    options = get_args(argparse.ArgumentParser())

    english_names = {
        "蒙德": "Mondstadt",
        "璃月": "Liyue",
        "稻妻": "Inazuma",
        "须弥": "Sumeru",
        "枫丹": "Fontaine",
    }
    return generate_abc(options, epochs, english_names[region])
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == "__main__":
|
213 |
+
with gr.Blocks() as demo:
|
214 |
+
gr.Markdown(
|
215 |
+
"""<center>欢迎使用此创空间,此创空间由bilibili <a href="https://space.bilibili.com/30620472">@亦真亦幻Studio</a> 基于 Tunesformer 开源项目制作,完全免费。</center>"""
|
216 |
+
)
|
217 |
+
with gr.Row():
|
218 |
+
with gr.Column():
|
219 |
+
weight_opt = gr.Dropdown(
|
220 |
+
choices=["5", "15"],
|
221 |
+
value="15",
|
222 |
+
label="模型选择(epochs)",
|
223 |
+
)
|
224 |
+
region_opt = gr.Dropdown(
|
225 |
+
choices=["蒙德", "璃月", "稻妻", "须弥", "枫丹"],
|
226 |
+
value="蒙德",
|
227 |
+
label="地区风格",
|
228 |
+
)
|
229 |
+
gen_btn = gr.Button("生成")
|
230 |
+
gr.Markdown(
|
231 |
+
"""
|
232 |
+
<center>
|
233 |
+
当前模型还在调试中,由于训练数据由 MIDI 转换而来,存在大量谱面不规范问题,导致很多生成结果有谱面不规范等问题。<br>
|
234 |
+
|
235 |
+
计划在原神主线杀青后,所有国家地区角色全部开放后,二创音乐会齐全且样本均衡,届时重新微调模型并添加现实风格筛选辅助游戏各国家输出把关,以提升输出区分度与质量。
|
236 |
+
|
237 |
+
数据来源:<a href="https://musescore.org">MuseScore</a><br>
|
238 |
+
Tag 嵌入数据来源:<a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
|
239 |
+
模型基础:<a href="https://github.com/sander-wood/tunesformer">Tunesformer</a>
|
240 |
+
|
241 |
+
注:崩铁方面数据工程正在运作中,未来也希望随主线杀青而基线化。</center>"""
|
242 |
+
)
|
243 |
+
|
244 |
+
with gr.Column():
|
245 |
+
wav_output = gr.Audio(label="音频", type="filepath")
|
246 |
+
dld_midi = gr.components.File(label="下载 MIDI")
|
247 |
+
pdf_score = gr.components.File(label="下载 PDF 乐谱")
|
248 |
+
dld_xml = gr.components.File(label="下载 MusicXML")
|
249 |
+
dld_mxl = gr.components.File(label="下载 MXL")
|
250 |
+
abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
|
251 |
+
img_score = gr.Image(label="五线谱", type="filepath")
|
252 |
+
|
253 |
+
gen_btn.click(
|
254 |
+
inference,
|
255 |
+
inputs=[weight_opt, region_opt],
|
256 |
+
outputs=[
|
257 |
+
abc_output,
|
258 |
+
dld_midi,
|
259 |
+
pdf_score,
|
260 |
+
dld_xml,
|
261 |
+
dld_mxl,
|
262 |
+
img_score,
|
263 |
+
wav_output,
|
264 |
+
],
|
265 |
+
)
|
266 |
+
|
267 |
+
demo.launch()
|
config.py
CHANGED
@@ -1,19 +1,9 @@
|
|
1 |
-
PATCH_LENGTH = 128
|
2 |
-
PATCH_SIZE = 32
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
#
|
8 |
-
|
9 |
-
|
10 |
-
# Batch size for patch during training, 0 for full context
|
11 |
-
PATCH_SAMPLING_BATCH_SIZE = 0
|
12 |
-
LOAD_FROM_CHECKPOINT = True # Whether to load weights from a checkpoint
|
13 |
-
# Whether to share weights between the encoder and decoder
|
14 |
-
SHARE_WEIGHTS = False
|
15 |
-
WEIGHT_URL = 'https://huggingface.co/MuGeminorum/hoyoGPT/resolve/main/weights.pth'
|
16 |
-
ZH_WEIGHT_URL = 'https://www.modelscope.cn/api/v1/models/MuGeminorum/hoyoGPT/repo?Revision=master&FilePath=weights.pth'
|
17 |
-
WEIGHT_PATH = 'weights.pth'
|
18 |
-
LOG_PATH = 'logs.txt'
|
19 |
-
PROMPT_PATH = 'prompt.txt'
|
|
|
1 |
+
PATCH_LENGTH = 128 # Patch Length
|
2 |
+
PATCH_SIZE = 32 # Patch Size
|
3 |
+
PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
|
4 |
+
CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
|
5 |
+
# Batch size for patch during training, 0 for full context
|
6 |
+
PATCH_SAMPLING_BATCH_SIZE = 0
|
7 |
+
# Whether to share weights between the encoder and decoder
|
8 |
+
SHARE_WEIGHTS = False
|
9 |
+
TEMP_DIR = "./tmp"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
convert.py
CHANGED
@@ -1,93 +1,88 @@
|
|
1 |
-
import os
|
2 |
-
import sys
|
3 |
-
import fitz
|
4 |
-
import subprocess
|
5 |
-
from PIL import Image
|
6 |
-
from music21 import converter
|
7 |
-
from utils import download
|
8 |
-
|
9 |
-
if sys.platform.startswith(
|
10 |
-
apkname =
|
11 |
-
extra_dir =
|
12 |
-
download(
|
13 |
-
filename=apkname,
|
14 |
-
url=
|
15 |
-
)
|
16 |
-
if not os.path.exists(extra_dir):
|
17 |
-
subprocess.run([
|
18 |
-
subprocess.run([f
|
19 |
-
|
20 |
-
|
21 |
-
os.environ[
|
22 |
-
|
23 |
-
else:
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
def abc_to_midi(abc_content, output_midi_path):
|
28 |
-
score = converter.parse(abc_content, format=
|
29 |
-
score.write(
|
30 |
-
return output_midi_path
|
31 |
-
|
32 |
-
|
33 |
-
def abc_to_musicxml(abc_content, output_xml_path):
|
34 |
-
score = converter.parse(abc_content, format=
|
35 |
-
score.write(
|
36 |
-
return output_xml_path
|
37 |
-
|
38 |
-
|
39 |
-
def musicxml_to_mxl(xml_path):
|
40 |
-
mxl_file = xml_path.replace(
|
41 |
-
command = [
|
42 |
-
result = subprocess.run(command)
|
43 |
-
print(result)
|
44 |
-
return mxl_file
|
45 |
-
|
46 |
-
|
47 |
-
def midi2wav(mid_file: str):
|
48 |
-
wav_file = mid_file.replace(
|
49 |
-
command = [
|
50 |
-
result = subprocess.run(command)
|
51 |
-
print(result)
|
52 |
-
return wav_file
|
53 |
-
|
54 |
-
|
55 |
-
def pdf2img(pdf_path: str):
|
56 |
-
output_path = pdf_path.replace(
|
57 |
-
doc = fitz.open(pdf_path)
|
58 |
-
# 创建一个图像列表
|
59 |
-
images = []
|
60 |
-
for page_number in range(doc.page_count):
|
61 |
-
page = doc[page_number]
|
62 |
-
# 将页面渲染为图像
|
63 |
-
image = page.get_pixmap()
|
64 |
-
# 将图像添加到列表
|
65 |
-
images.append(
|
66 |
-
Image.frombytes(
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
pdf_score = mxl_file.replace('.mxl', '.pdf')
|
90 |
-
command = [mscore, "-o", pdf_score, mxl_file]
|
91 |
-
result = subprocess.run(command)
|
92 |
-
print(result)
|
93 |
-
return pdf_score, pdf2img(pdf_score)
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import fitz
|
4 |
+
import subprocess
|
5 |
+
from PIL import Image
|
6 |
+
from music21 import converter
|
7 |
+
from utils import download
|
8 |
+
|
9 |
+
if sys.platform.startswith("linux"):
|
10 |
+
apkname = "MuseScore.AppImage"
|
11 |
+
extra_dir = "squashfs-root"
|
12 |
+
download(
|
13 |
+
filename=apkname,
|
14 |
+
url="https://master.dl.sourceforge.net/project/musescore.mirror/v4.2.0/MuseScore-4.2.0.233521125-x86_64.AppImage?viasf=1",
|
15 |
+
)
|
16 |
+
if not os.path.exists(extra_dir):
|
17 |
+
subprocess.run(["chmod", "+x", f"./{apkname}"])
|
18 |
+
subprocess.run([f"./{apkname}", "--appimage-extract"])
|
19 |
+
|
20 |
+
MSCORE = f"./{extra_dir}/AppRun"
|
21 |
+
os.environ["QT_QPA_PLATFORM"] = "offscreen"
|
22 |
+
|
23 |
+
else:
|
24 |
+
MSCORE = "D:/Program Files/MuseScore 3/bin/MuseScore3.exe"
|
25 |
+
|
26 |
+
|
27 |
+
def abc_to_midi(abc_content, output_midi_path):
    """Parse ABC text with music21 and write it to *output_midi_path* as MIDI.

    :return: *output_midi_path*, unchanged
    """
    converter.parse(abc_content, format="abc").write("midi", fp=output_midi_path)
    return output_midi_path
|
31 |
+
|
32 |
+
|
33 |
+
def abc_to_musicxml(abc_content, output_xml_path):
    """Parse ABC text with music21 and write it to *output_xml_path* as MusicXML.

    :return: *output_xml_path*, unchanged
    """
    converter.parse(abc_content, format="abc").write("musicxml", fp=output_xml_path)
    return output_xml_path
|
37 |
+
|
38 |
+
|
39 |
+
def musicxml_to_mxl(xml_path):
    """Convert a MusicXML file to MXL by running the MuseScore binary.

    The output path is built by swapping only the file extension. The previous
    ``str.replace`` also rewrote any ``.musicxml`` found in directory names.

    :param xml_path: path of the MusicXML file to convert
    :return: path of the ``.mxl`` file MuseScore was asked to write
    """
    stem, _ = os.path.splitext(xml_path)
    mxl_file = f"{stem}.mxl"
    # Argument list (no shell), so unusual characters in the path are safe.
    result = subprocess.run([MSCORE, "-o", mxl_file, xml_path])
    print(result)
    return mxl_file
|
45 |
+
|
46 |
+
|
47 |
+
def midi2wav(mid_file: str):
    """Render a MIDI file to WAV by running the MuseScore binary.

    Only the file extension is swapped. The previous ``str.replace`` rewrote
    every ``.mid`` in the path, so a directory such as ``.midi`` was corrupted.

    :param mid_file: path of the MIDI file to render
    :return: path of the ``.wav`` file MuseScore was asked to write
    """
    stem, _ = os.path.splitext(mid_file)
    wav_file = f"{stem}.wav"
    # Argument list (no shell), so unusual characters in the path are safe.
    result = subprocess.run([MSCORE, "-o", wav_file, mid_file])
    print(result)
    return wav_file
|
53 |
+
|
54 |
+
|
55 |
+
def pdf2img(pdf_path: str):
    """Render every page of a PDF and stack the pages vertically into one JPG.

    The document is opened in a ``with`` block so it is closed even when
    rendering or saving fails.

    :param pdf_path: path of the PDF to render
    :return: path of the written ``.jpg`` (same location, extension swapped)
    :raises ValueError: if the PDF has no pages
    """
    stem, _ = os.path.splitext(pdf_path)
    output_path = f"{stem}.jpg"

    with fitz.open(pdf_path) as doc:
        # Render each page to a pixmap and copy its pixels into a PIL image.
        pages = []
        for page in doc:
            pixmap = page.get_pixmap()
            pages.append(
                Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
            )

    if not pages:
        raise ValueError(f"'{pdf_path}' contains no pages to render")

    # The canvas is as wide as the first page and as tall as all pages together.
    merged_image = Image.new(
        "RGB", (pages[0].width, sum(page.height for page in pages))
    )
    y_offset = 0
    for page in pages:
        merged_image.paste(page, (0, y_offset))
        y_offset += page.height

    merged_image.save(output_path, "JPEG")
    return output_path
|
81 |
+
|
82 |
+
|
83 |
+
def mxl2jpg(mxl_file: str):
    """Engrave an MXL file to PDF with MuseScore, then rasterise the PDF to JPG.

    Only the file extension is swapped for the PDF path. The previous
    ``str.replace`` also rewrote any ``.mxl`` found in directory names.

    :param mxl_file: path of the MXL score
    :return: ``(pdf_path, jpg_path)``
    """
    stem, _ = os.path.splitext(mxl_file)
    pdf_score = f"{stem}.pdf"
    # Argument list (no shell), so unusual characters in the path are safe.
    result = subprocess.run([MSCORE, "-o", pdf_score, mxl_file])
    print(result)
    return pdf_score, pdf2img(pdf_score)
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
transformers==4.18.0
|
2 |
-
samplings==0.1.7
|
3 |
-
unidecode
|
4 |
-
music21
|
5 |
-
autopep8
|
6 |
-
pillow==9.4.0
|
7 |
-
gradio
|
8 |
-
pymupdf
|
9 |
torch
|
|
|
1 |
+
transformers==4.18.0
|
2 |
+
samplings==0.1.7
|
3 |
+
unidecode
|
4 |
+
music21
|
5 |
+
autopep8
|
6 |
+
pillow==9.4.0
|
7 |
+
gradio==4.8.0
|
8 |
+
pymupdf
|
9 |
torch
|
utils.py
CHANGED
@@ -1,394 +1,377 @@
|
|
1 |
-
import
|
2 |
-
import
|
3 |
-
import torch
|
4 |
-
import random
|
5 |
-
|
6 |
-
from
|
7 |
-
from
|
8 |
-
from
|
9 |
-
from
|
10 |
-
from
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
""
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
"""
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
)
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
self.patch_level_decoder
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
:
|
311 |
-
:
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
)
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
return
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
def __init__(self, items, patchilizer):
|
379 |
-
self.texts = []
|
380 |
-
|
381 |
-
for item in tqdm(items):
|
382 |
-
text = item["control code"] + "\n".join(
|
383 |
-
item["abc notation"].split("\n")[1:]
|
384 |
-
)
|
385 |
-
input_patch = patchilizer.encode(text, add_special_patches=True)
|
386 |
-
input_patch = torch.tensor(input_patch)
|
387 |
-
if torch.sum(input_patch) != 0:
|
388 |
-
self.texts.append(input_patch)
|
389 |
-
|
390 |
-
def __len__(self):
|
391 |
-
return len(self.texts)
|
392 |
-
|
393 |
-
def __getitem__(self, idx):
|
394 |
-
return self.texts[idx]
|
|
|
1 |
+
import re
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import requests
|
6 |
+
from config import *
|
7 |
+
from tqdm import tqdm
|
8 |
+
from unidecode import unidecode
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
|
11 |
+
from samplings import top_p_sampling, top_k_sampling, temperature_sampling
|
12 |
+
|
13 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
14 |
+
|
15 |
+
|
16 |
+
def download(filename: str, url: str, max_attempts: int = 5, retry_delay: float = 10.0):
    """Stream *url* into *filename*, showing a progress bar.

    Changes from the recursive original:

    * retries run in a bounded loop instead of unbounded recursion;
    * HTTP error statuses raise, so an error page is never saved as the file;
    * the request has a timeout and the connection is always released.

    :param filename: destination path; overwritten on every attempt
    :param url: address to download from
    :param max_attempts: how many times to try before giving up
    :param retry_delay: seconds to wait between attempts
    :raises RuntimeError: if all attempts failed; chained to the last error
    """
    chunk_size = 1024
    attempts = max(1, max_attempts)
    last_error = None

    for attempt in range(1, attempts + 1):
        try:
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))

                with open(filename, "wb") as file, tqdm(
                    desc=f"Downloading {filename} from '{url}'...",
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for data in response.iter_content(chunk_size=chunk_size):
                        size = file.write(data)
                        bar.update(size)
            return

        except (requests.RequestException, OSError) as e:
            last_error = e
            if attempt < attempts:
                print(f"Error: {e}, retrying...")
                time.sleep(retry_delay)

    raise RuntimeError(
        f"failed to download '{url}' after {attempts} attempts"
    ) from last_error
|
37 |
+
|
38 |
+
|
39 |
+
class Patchilizer:
    """
    A class for converting music bars to patches and vice versa.

    A patch is a fixed-length list of character codes framed by the BOS and
    EOS ids and right-padded with the PAD id.
    """

    def __init__(self):
        # Longer delimiters come first so the regex alternation prefers them
        # over the single "|".
        self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
        self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2

    def split_bars(self, body):
        """
        Split a body of music into individual bars.

        Each returned bar is the bar text followed by its closing delimiter.
        A leading delimiter is attached to the first bar. Text left over after
        the last delimiter is dropped, as in the original implementation.
        An empty body or a lone delimiter yields ``[]`` instead of raising
        ``IndexError``.
        """
        pieces = [piece for piece in re.split(self.regexPattern, "".join(body)) if piece]
        if not pieces:
            return []

        # Merge a leading delimiter into the piece after it, if there is one.
        if pieces[0] in self.delimiters and len(pieces) > 1:
            pieces[1] = pieces[0] + pieces[1]
            pieces = pieces[1:]

        # Pair each bar text with the delimiter that closes it.
        return [
            pieces[i * 2] + pieces[i * 2 + 1] for i in range(len(pieces) // 2)
        ]

    def bar2patch(self, bar, patch_size=PATCH_SIZE):
        """
        Convert a bar into a patch of specified length.

        The bar is framed as ``[BOS] + codes + [EOS]``, cut to *patch_size*
        and right-padded with PAD.
        """
        patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
        patch = patch[:patch_size]
        patch += [self.pad_token_id] * (patch_size - len(patch))
        return patch

    def patch2bar(self, patch):
        """
        Convert a patch into a bar.

        Ids up to and including the EOS id (PAD, BOS, EOS) are dropped; every
        other id is turned back into its character.
        """
        return "".join(chr(idx) for idx in patch if idx > self.eos_token_id)

    def encode(
        self,
        abc_code,
        patch_length=PATCH_LENGTH,
        patch_size=PATCH_SIZE,
        add_special_patches=False,
    ):
        """
        Encode music into patches of specified length.

        Header lines (``X:``-style fields and ``%%score``) become one patch
        each. Body lines are accumulated and split into one patch per bar.

        :param abc_code: the ABC notation text
        :param patch_length: maximum number of patches returned
        :param patch_size: number of ids per patch
        :param add_special_patches: add a BOS patch before and an EOS patch after
        :return: a list of at most *patch_length* patches
        """
        lines = unidecode(abc_code).split("\n")
        lines = list(filter(None, lines))  # remove empty lines

        body = ""
        patches = []

        for line in lines:
            is_header = len(line) > 1 and (
                (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
            )
            if not is_header:
                body += line + "\n"
                continue

            # A header line ends the pending body: flush its bars first and
            # give the last bar the newline that separated it from the header.
            if body:
                bars = self.split_bars(body)
                patches.extend(
                    self.bar2patch(
                        bar + "\n" if idx == len(bars) - 1 else bar, patch_size
                    )
                    for idx, bar in enumerate(bars)
                )
                body = ""

            patches.append(self.bar2patch(line + "\n", patch_size))

        if body:
            patches.extend(
                self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
            )

        if add_special_patches:
            bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
            eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
            patches = [bos_patch] + patches + [eos_patch]

        return patches[:patch_length]

    def decode(self, patches):
        """
        Decode patches into music.
        """
        return "".join(self.patch2bar(patch) for patch in patches)
|
136 |
+
|
137 |
+
|
138 |
+
class PatchLevelDecoder(PreTrainedModel):
    """
    Patch-level decoder: embeds whole patches and runs them through a GPT-2
    backbone to produce one feature vector per patch, auto-regressively.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        # A patch enters as PATCH_SIZE one-hot vectors of width 128, flattened.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """
        Run the patch-level decoder.
        :param patches: the patches to be encoded
        :return: the output of the GPT-2 backbone for the embedded patches
        """
        one_hot = torch.nn.functional.one_hot(patches, num_classes=128).float()
        flattened = one_hot.reshape(len(patches), -1, PATCH_SIZE * 128)
        embedded = self.patch_embedding(flattened.to(self.device))
        return self.base(inputs_embeds=embedded)
|
161 |
+
|
162 |
+
|
163 |
+
class CharLevelDecoder(PreTrainedModel):
    """
    Char-level decoder: generates the characters inside each bar patch one
    after another, conditioned on that patch's encoded feature.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.base = GPT2LMHeadModel(config)

    def forward(
        self,
        encoded_patches: torch.Tensor,
        target_patches: torch.Tensor,
        patch_sampling_batch_size: int,
    ):
        """
        Run the char-level decoder in training mode.
        :param encoded_patches: the encoded patches
        :param target_patches: the target patches
        :param patch_sampling_batch_size: if non-zero and smaller than the
            number of target patches, only that many randomly chosen patches
            are used
        :return: the output of the GPT-2 LM head model, computed with labels
        """
        # Padding positions get label -100 so the loss ignores them.
        is_padding = target_patches == self.pad_token_id
        labels = target_patches.clone().masked_fill_(is_padding, -100)

        # Attention mask: 1 on real tokens, 0 on ignored positions.
        attention_mask = torch.ones_like(labels)
        attention_mask = attention_mask.masked_fill_(labels == -100, 0)

        # Optionally train on a random subset of the patches, kept in order.
        use_subset = (
            patch_sampling_batch_size != 0
            and patch_sampling_batch_size < target_patches.shape[0]
        )
        if use_subset:
            order = list(range(len(target_patches)))
            random.shuffle(order)
            chosen = sorted(order[:patch_sampling_batch_size])

            target_patches = target_patches[chosen, :]
            attention_mask = attention_mask[chosen, :]
            encoded_patches = encoded_patches[chosen, :]
            labels = labels[chosen, :]

        # Embed the target characters with the backbone's token embedding table.
        char_embeds = torch.nn.functional.embedding(
            target_patches, self.base.transformer.wte.weight
        )

        # The encoded patch replaces the embedding at the first position.
        inputs_embeds = torch.cat(
            (encoded_patches.unsqueeze(1), char_embeds[:, 1:, :]), dim=1
        )

        return self.base(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels
        )

    def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
        """
        Compute the next-token distribution for a patch being generated.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch
        :return: the probability distribution of next token
        """
        patch_feature = encoded_patch.reshape(1, 1, -1)
        token_ids = tokens.reshape(1, -1)

        # Embed the tokens generated so far.
        token_embeds = torch.nn.functional.embedding(
            token_ids, self.base.transformer.wte.weight
        )

        # The encoded patch replaces the embedding at the first position.
        inputs_embeds = torch.cat((patch_feature, token_embeds[:, 1:, :]), dim=1)

        outputs = self.base(inputs_embeds=inputs_embeds)

        # Softmax over the logits at the last position.
        last_logits = outputs.logits.squeeze(0)[-1]
        return torch.nn.functional.softmax(last_logits, dim=-1)
|
247 |
+
|
248 |
+
|
249 |
+
class TunesFormer(PreTrainedModel):
    """
    TunesFormer is a hierarchical music generation model based on bar patching.
    It includes a patch-level decoder and a character-level decoder.
    It inherits PreTrainedModel from transformers.
    """

    def __init__(self, encoder_config, decoder_config, share_weights=False):
        """
        Build the patch-level and character-level decoders.
        :param encoder_config: config of the patch-level decoder (also handed to PreTrainedModel)
        :param decoder_config: config of the character-level decoder
        :param share_weights: if True, both configs are aligned IN PLACE to the larger
            num_hidden_layers / max_length / max_position_embeddings of the two, and the
            patch-level decoder reuses the character-level decoder's transformer
        """
        super().__init__(encoder_config)
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        if share_weights:
            # Both decoders must have the same geometry before one transformer
            # can serve both, so take the maximum of every differing setting.
            max_layers = max(
                encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
            )

            max_context_size = max(encoder_config.max_length, decoder_config.max_length)

            max_position_embeddings = max(
                encoder_config.max_position_embeddings,
                decoder_config.max_position_embeddings,
            )

            encoder_config.num_hidden_layers = max_layers
            encoder_config.max_length = max_context_size
            encoder_config.max_position_embeddings = max_position_embeddings
            decoder_config.num_hidden_layers = max_layers
            decoder_config.max_length = max_context_size
            decoder_config.max_position_embeddings = max_position_embeddings

        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

        if share_weights:
            # Replace the patch-level base with the char-level transformer so
            # that both decoders use the very same module.
            self.patch_level_decoder.base = self.char_level_decoder.base.transformer

    def forward(
        self,
        patches: torch.Tensor,
        patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
    ):
        """
        The forward pass of the TunesFormer model.
        :param patches: the patches to be both encoded and decoded
        :param patch_sampling_batch_size: forwarded unchanged to the character-level decoder
        :return: the decoded patches
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

        # Each encoded patch is paired with the patch that FOLLOWS it: drop the
        # last encoding and the first target patch.
        # NOTE(review): squeeze(0) only removes the leading axis when it has
        # size 1 - presumably forward is fed one tune at a time; confirm.
        return self.char_level_decoder(
            encoded_patches.squeeze(0)[:-1, :],
            patches.squeeze(0)[1:, :],
            patch_sampling_batch_size,
        )

    def generate(
        self,
        patches: torch.Tensor,
        tokens: torch.Tensor,
        top_p: float = 1,
        top_k: int = 0,
        temperature: float = 1,
        seed: int = None,
    ):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded
        :param tokens: tokens already generated for the next patch, or None to start
            from a single bos token
        :param top_p: forwarded to top_p_sampling
        :param top_k: forwarded to top_k_sampling
        :param temperature: forwarded to temperature_sampling
        :param seed: seed for the global `random` module; when not None a fresh
            per-token seed is derived from it at every step
        :return: the generated patches (list of sampled tokens) and the seed used for
            the last sampled token (None when `seed` is None)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]

        # Identity check: `tokens == None` on a tensor only works because the
        # failed comparison falls back to identity.
        if tokens is None:
            tokens = torch.tensor([self.bos_token_id], device=self.device)

        # Only the encoding of the last patch conditions the new patch; it does
        # not change while sampling, so look it up once.
        last_encoded_patch = encoded_patches[0][-1]

        generated_patch = []
        random.seed(seed)

        while True:
            if seed is not None:
                # Chain the seeds: draw the seed for this step, then reseed with
                # it so that the next step's seed is derived from this one.
                n_seed = random.randint(0, 1000000)
                random.seed(n_seed)

            else:
                n_seed = None

            prob = (
                self.char_level_decoder.generate(last_encoded_patch, tokens)
                .cpu()
                .detach()
                .numpy()
            )

            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)

            token = temperature_sampling(prob, temperature=temperature, seed=n_seed)

            # The token is recorded before the stop test, so the stopping token
            # is part of the returned patch.
            generated_patch.append(token)
            if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
                break

            else:
                tokens = torch.cat(
                    (tokens, torch.tensor([token], device=self.device)), dim=0
                )

        return generated_patch, n_seed
|
358 |
+
|
359 |
+
|
360 |
+
class PatchilizedData(Dataset):
    """
    Dataset of tunes that were encoded into patch tensors up front.
    """

    def __init__(self, items, patchilizer):
        """
        Encode every item once and keep the resulting tensors in memory.
        :param items: iterable of mappings with the keys "control code" and "abc notation"
        :param patchilizer: object whose encode(text, add_special_patches=True) result
            can be turned into a tensor
        """
        self.texts = []

        for entry in tqdm(items):
            control_code = entry["control code"]
            # The first line of the ABC notation is dropped; the control code
            # is put in front of the remaining lines.
            remaining_lines = entry["abc notation"].split("\n")[1:]
            encoded = torch.tensor(
                patchilizer.encode(
                    control_code + "\n".join(remaining_lines),
                    add_special_patches=True,
                )
            )
            # Skip encodings whose elements sum to zero
            if torch.sum(encoded) == 0:
                continue
            self.texts.append(encoded)

    def __len__(self):
        """Number of stored patch tensors."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the stored patch tensor at position idx."""
        return self.texts[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|