Monan Zhou committed on
Commit 6c48757
Parent: dc7b3e1

Upload 5 files

Files changed (5)
  1. app.py +267 -237
  2. config.py +9 -19
  3. convert.py +88 -93
  4. requirements.txt +8 -8
  5. utils.py +377 -394
app.py CHANGED
@@ -1,237 +1,267 @@
- import re
- import os
- import time
- import torch
- import shutil
- import argparse
- import gradio as gr
- from utils import *
- from config import *
- from convert import *
- from transformers import GPT2Config
- import warnings
-
- warnings.filterwarnings("ignore")
-
-
- def get_args(parser):
-     parser.add_argument(
-         "-num_tunes",
-         type=int,
-         default=1,
-         help="the number of independently computed returned tunes",
-     )
-     parser.add_argument(
-         "-max_patch",
-         type=int,
-         default=128,
-         help="integer to define the maximum length in tokens of each tune",
-     )
-     parser.add_argument(
-         "-top_p",
-         type=float,
-         default=0.8,
-         help="float to define the tokens that are within the sample operation of text generation",
-     )
-     parser.add_argument(
-         "-top_k",
-         type=int,
-         default=8,
-         help="integer to define the tokens that are within the sample operation of text generation",
-     )
-     parser.add_argument(
-         "-temperature",
-         type=float,
-         default=1.2,
-         help="the temperature of the sampling operation",
-     )
-     parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
-     parser.add_argument(
-         "-show_control_code",
-         type=bool,
-         default=True,
-         help="whether to show control code",
-     )
-     args = parser.parse_args()
-
-     return args
-
-
- def generate_abc(args, region):
-     patchilizer = Patchilizer()
-
-     patch_config = GPT2Config(
-         num_hidden_layers=PATCH_NUM_LAYERS,
-         max_length=PATCH_LENGTH,
-         max_position_embeddings=PATCH_LENGTH,
-         vocab_size=1,
-     )
-
-     char_config = GPT2Config(
-         num_hidden_layers=CHAR_NUM_LAYERS,
-         max_length=PATCH_SIZE,
-         max_position_embeddings=PATCH_SIZE,
-         vocab_size=128,
-     )
-
-     model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
-
-     filename = WEIGHT_PATH
-
-     if os.path.exists(filename):
-         print(f"Weights already exist at '{filename}'. Loading...")
-
-     else:
-         download()
-
-     checkpoint = torch.load(filename, map_location=torch.device("cpu"))
-     model.load_state_dict(checkpoint["model"])
-     model = model.to(device)
-     model.eval()
-
-     prompt = template(region)
-
-     tunes = ""
-     num_tunes = args.num_tunes
-     max_patch = args.max_patch
-     top_p = args.top_p
-     top_k = args.top_k
-     temperature = args.temperature
-     seed = args.seed
-     show_control_code = args.show_control_code
-
-     print(" HYPERPARAMETERS ".center(60, "#"), "\n")
-     args = vars(args)
-
-     for key in args.keys():
-         print(f"{key}: {str(args[key])}")
-
-     print("\n", " OUTPUT TUNES ".center(60, "#"))
-
-     start_time = time.time()
-
-     for i in range(num_tunes):
-         title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
-         tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
-         lines = re.split(r"(\n)", tune)
-         tune = ""
-         skip = False
-         for line in lines:
-             if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
-                 if not skip:
-                     print(line, end="")
-                     tune += line
-
-                 skip = False
-
-             else:
-                 skip = True
-
-         input_patches = torch.tensor(
-             [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device
-         )
-
-         if tune == "":
-             tokens = None
-
-         else:
-             prefix = patchilizer.decode(input_patches[0])
-             remaining_tokens = prompt[len(prefix) :]
-             tokens = torch.tensor(
-                 [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
-                 device=device,
-             )
-
-         while input_patches.shape[1] < max_patch:
-             predicted_patch, seed = model.generate(
-                 input_patches,
-                 tokens,
-                 top_p=top_p,
-                 top_k=top_k,
-                 temperature=temperature,
-                 seed=seed,
-             )
-             tokens = None
-
-             if predicted_patch[0] != patchilizer.eos_token_id:
-                 next_bar = patchilizer.decode([predicted_patch])
-
-                 if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
-                     print(next_bar, end="")
-                     tune += next_bar
-
-                 if next_bar == "":
-                     break
-
-                 next_bar = remaining_tokens + next_bar
-                 remaining_tokens = ""
-
-                 predicted_patch = torch.tensor(
-                     patchilizer.bar2patch(next_bar), device=device
-                 ).unsqueeze(0)
-
-                 input_patches = torch.cat(
-                     [input_patches, predicted_patch.unsqueeze(0)], dim=1
-                 )
-
-             else:
-                 break
-
-         tunes += f"{tune}\n\n"
-         print("\n")
-
-     print("Generation time: {:.2f} seconds".format(time.time() - start_time))
-     os.makedirs("./tmp", exist_ok=True)
-     timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
-     out_midi = abc_to_midi(tunes, f"./tmp/[{region}]{timestamp}.mid")
-     out_xml = abc_to_musicxml(tunes, f"./tmp/[{region}]{timestamp}.musicxml")
-     out_mxl = musicxml_to_mxl(f"./tmp/[{region}]{timestamp}.musicxml")
-     pdf_file, jpg_file = mxl2jpg(out_mxl)
-     wav_file = midi2wav(out_midi)
-
-     return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file
-
-
- def inference(region):
-     if os.path.exists("./tmp"):
-         shutil.rmtree("./tmp")
-
-     parser = argparse.ArgumentParser()
-     args = get_args(parser)
-     return generate_abc(args, region)
-
-
- with gr.Blocks() as demo:
-     with gr.Row():
-         with gr.Column():
-             region_opt = gr.Dropdown(
-                 choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
-                 value="Mondstadt",
-                 label="Region genre",
-             )
-             gen_btn = gr.Button("Generate")
-
-         with gr.Column():
-             wav_output = gr.Audio(label="Audio", type="filepath")
-             dld_midi = gr.components.File(label="Download MIDI")
-             pdf_score = gr.components.File(label="Download PDF score")
-             dld_xml = gr.components.File(label="Download MusicXML")
-             dld_mxl = gr.components.File(label="Download MXL")
-             abc_output = gr.Textbox(label="abc score", show_copy_button=True)
-             img_score = gr.Image(label="Staff", type="filepath")
-
-     gen_btn.click(
-         inference,
-         inputs=region_opt,
-         outputs=[
-             abc_output,
-             dld_midi,
-             pdf_score,
-             dld_xml,
-             dld_mxl,
-             img_score,
-             wav_output,
-         ],
-     )
-
- demo.launch(share=True)
+ import re
+ import os
+ import time
+ import torch
+ import shutil
+ import argparse
+ import gradio as gr
+ from config import *
+ from utils import Patchilizer, TunesFormer, DEVICE
+ from convert import abc_to_midi, abc_to_musicxml, musicxml_to_mxl, mxl2jpg, midi2wav
+ from modelscope import snapshot_download
+ from transformers import GPT2Config
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+ # Model download
+ MODEL_DIR = snapshot_download("MuGeminorum/hoyoGPT")
+
+
+ def get_args(parser: argparse.ArgumentParser):
+     parser.add_argument(
+         "-num_tunes",
+         type=int,
+         default=1,
+         help="the number of independently computed returned tunes",
+     )
+     parser.add_argument(
+         "-max_patch",
+         type=int,
+         default=128,
+         help="integer to define the maximum length in tokens of each tune",
+     )
+     parser.add_argument(
+         "-top_p",
+         type=float,
+         default=0.8,
+         help="float to define the tokens that are within the sample operation of text generation",
+     )
+     parser.add_argument(
+         "-top_k",
+         type=int,
+         default=8,
+         help="integer to define the tokens that are within the sample operation of text generation",
+     )
+     parser.add_argument(
+         "-temperature",
+         type=float,
+         default=1.2,
+         help="the temperature of the sampling operation",
+     )
+     parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
+     parser.add_argument(
+         "-show_control_code",
+         type=bool,
+         default=False,
+         help="whether to show control code",
+     )
+     args = parser.parse_args()
+
+     return args
+
+
+ def generate_abc(args, epochs: str, region: str):
+     patchilizer = Patchilizer()
+
+     patch_config = GPT2Config(
+         num_hidden_layers=PATCH_NUM_LAYERS,
+         max_length=PATCH_LENGTH,
+         max_position_embeddings=PATCH_LENGTH,
+         vocab_size=1,
+     )
+
+     char_config = GPT2Config(
+         num_hidden_layers=CHAR_NUM_LAYERS,
+         max_length=PATCH_SIZE,
+         max_position_embeddings=PATCH_SIZE,
+         vocab_size=128,
+     )
+
+     model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
+     filename = f"{MODEL_DIR}/{epochs}/weights.pth"
+     checkpoint = torch.load(filename, map_location=torch.device("cpu"))
+     model.load_state_dict(checkpoint["model"])
+     model = model.to(DEVICE)
+     model.eval()
+
+     prompt = f"A:{region}\n"
+
+     tunes = ""
+     num_tunes = args.num_tunes
+     max_patch = args.max_patch
+     top_p = args.top_p
+     top_k = args.top_k
+     temperature = args.temperature
+     seed = args.seed
+     show_control_code = args.show_control_code
+
+     print(" HYPERPARAMETERS ".center(60, "#"), "\n")
+     arg_dict: dict = vars(args)
+
+     for key in arg_dict.keys():
+         print(f"{key}: {str(arg_dict[key])}")
+
+     print("\n", " OUTPUT TUNES ".center(60, "#"))
+
+     start_time = time.time()
+     for i in range(num_tunes):
+         title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
+         tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
+         lines = re.split(r"(\n)", tune)
+         tune = ""
+         skip = False
+         for line in lines:
+             if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
+                 if not skip:
+                     print(line, end="")
+                     tune += line
+
+                 skip = False
+
+             else:
+                 skip = True
+
+         input_patches = torch.tensor(
+             [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
+         )
+
+         if tune == "":
+             tokens = None
+
+         else:
+             prefix = patchilizer.decode(input_patches[0])
+             remaining_tokens = prompt[len(prefix) :]
+             tokens = torch.tensor(
+                 [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
+                 device=DEVICE,
+             )
+
+         while input_patches.shape[1] < max_patch:
+             predicted_patch, seed = model.generate(
+                 input_patches,
+                 tokens,
+                 top_p=top_p,
+                 top_k=top_k,
+                 temperature=temperature,
+                 seed=seed,
+             )
+             tokens = None
+
+             if predicted_patch[0] != patchilizer.eos_token_id:
+                 next_bar = patchilizer.decode([predicted_patch])
+
+                 if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
+                     print(next_bar, end="")
+                     tune += next_bar
+
+                 if next_bar == "":
+                     break
+
+                 next_bar = remaining_tokens + next_bar
+                 remaining_tokens = ""
+
+                 predicted_patch = torch.tensor(
+                     patchilizer.bar2patch(next_bar), device=DEVICE
+                 ).unsqueeze(0)
+
+                 input_patches = torch.cat(
+                     [input_patches, predicted_patch.unsqueeze(0)], dim=1
+                 )
+
+             else:
+                 break
+
+         tunes += f"{tune}\n\n"
+         print("\n")
+
+     print("Generation time: {:.2f} seconds".format(time.time() - start_time))
+     os.makedirs(TEMP_DIR, exist_ok=True)
+     timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
+     try:
+         out_midi = abc_to_midi(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.mid")
+         out_xml = abc_to_musicxml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
+         out_mxl = musicxml_to_mxl(f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
+         pdf_file, jpg_file = mxl2jpg(out_mxl)
+         wav_file = midi2wav(out_midi)
+
+         return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file
+
+     except Exception as e:
+         print(f"Invalid abc generated: {e}, retrying...")
+         return generate_abc(args, epochs, region)
+
+
+ def inference(epochs: str, region: str):
+     Teyvat = {
+         "蒙德": "Mondstadt",
+         "璃月": "Liyue",
+         "稻妻": "Inazuma",
+         "须弥": "Sumeru",
+         "枫丹": "Fontaine",
+     }
+
+     if os.path.exists(TEMP_DIR):
+         shutil.rmtree(TEMP_DIR)
+
+     parser = argparse.ArgumentParser()
+     args = get_args(parser)
+     return generate_abc(args, epochs, Teyvat[region])
+
+
+ if __name__ == "__main__":
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             """<center>欢迎使用此创空间,此创空间由bilibili <a href="https://space.bilibili.com/30620472">@亦真亦幻Studio</a> 基于 Tunesformer 开源项目制作,完全免费。</center>"""
+         )
+         with gr.Row():
+             with gr.Column():
+                 weight_opt = gr.Dropdown(
+                     choices=["5", "15"],
+                     value="15",
+                     label="模型选择(epochs)",
+                 )
+                 region_opt = gr.Dropdown(
+                     choices=["蒙德", "璃月", "稻妻", "须弥", "枫丹"],
+                     value="蒙德",
+                     label="地区风格",
+                 )
+                 gen_btn = gr.Button("生成")
+                 gr.Markdown(
+                     """
+                     <center>
+                     当前模型还在调试中,由于训练数据由 MIDI 转换而来,存在大量谱面不规范问题,导致很多生成结果有谱面不规范等问题。<br>
+
+                     计划在原神主线杀青后,所有国家地区角色全部开放后,二创音乐会齐全且样本均衡,届时重新微调模型并添加现实风格筛选辅助游戏各国家输出把关,以提升输出区分度与质量。
+
+                     数据来源:<a href="https://musescore.org">MuseScore</a><br>
+                     Tag 嵌入数据来源:<a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
+                     模型基础:<a href="https://github.com/sander-wood/tunesformer">Tunesformer</a>
+
+                     注:崩铁方面数据工程正在运作中,未来也希望随主线杀青而基线化。</center>"""
+                 )
+
+             with gr.Column():
+                 wav_output = gr.Audio(label="音频", type="filepath")
+                 dld_midi = gr.components.File(label="下载 MIDI")
+                 pdf_score = gr.components.File(label="下载 PDF 乐谱")
+                 dld_xml = gr.components.File(label="下载 MusicXML")
+                 dld_mxl = gr.components.File(label="下载 MXL")
+                 abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
+                 img_score = gr.Image(label="五线谱", type="filepath")
+
+         gen_btn.click(
+             inference,
+             inputs=[weight_opt, region_opt],
+             outputs=[
+                 abc_output,
+                 dld_midi,
+                 pdf_score,
+                 dld_xml,
+                 dld_mxl,
+                 img_score,
+                 wav_output,
+             ],
+         )
+
+     demo.launch()
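For quick verification, a minimal sketch of exercising the reworked entry point outside the Gradio UI (a hypothetical smoke test, assuming this commit's app.py is on the import path and that the MuGeminorum/hoyoGPT snapshot contains the 5/ and 15/ checkpoint folders implied by the dropdown choices):

# Hypothetical smoke test; the Gradio blocks are guarded by `if __name__ == "__main__"`,
# so importing app only triggers the snapshot_download of the weights.
from app import inference

# "15" selects the 15-epoch checkpoint; "蒙德" is mapped to "Mondstadt" via Teyvat.
tunes, midi, pdf, xml, mxl, jpg, wav = inference("15", "蒙德")
print(tunes)  # generated abc notation
print(wav)    # path to the rendered audio under TEMP_DIR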
config.py CHANGED
@@ -1,19 +1,9 @@
- PATCH_LENGTH = 128 # Patch Length
- PATCH_SIZE = 32 # Patch Size
-
- PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
- CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
-
- # Number of epochs to train for (if early stopping doesn't intervene)
- NUM_EPOCHS = 5 # 32
- LEARNING_RATE = 5e-5 # Learning rate for the optimizer
- # Batch size for patch during training, 0 for full context
- PATCH_SAMPLING_BATCH_SIZE = 0
- LOAD_FROM_CHECKPOINT = True # Whether to load weights from a checkpoint
- # Whether to share weights between the encoder and decoder
- SHARE_WEIGHTS = False
- WEIGHT_URL = 'https://huggingface.co/MuGeminorum/hoyoGPT/resolve/main/weights.pth'
- ZH_WEIGHT_URL = 'https://www.modelscope.cn/api/v1/models/MuGeminorum/hoyoGPT/repo?Revision=master&FilePath=weights.pth'
- WEIGHT_PATH = 'weights.pth'
- LOG_PATH = 'logs.txt'
- PROMPT_PATH = 'prompt.txt'
+ PATCH_LENGTH = 128 # Patch Length
+ PATCH_SIZE = 32 # Patch Size
+ PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
+ CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
+ # Batch size for patch during training, 0 for full context
+ PATCH_SAMPLING_BATCH_SIZE = 0
+ # Whether to share weights between the encoder and decoder
+ SHARE_WEIGHTS = False
+ TEMP_DIR = "./tmp"
convert.py CHANGED
@@ -1,93 +1,88 @@
- import os
- import sys
- import fitz
- import subprocess
- from PIL import Image
- from music21 import converter
- from utils import download
-
- if sys.platform.startswith('linux'):
-     apkname = 'MuseScore.AppImage'
-     extra_dir = 'squashfs-root'
-     download(
-         filename=apkname,
-         url='https://cdn.jsdelivr.net/musescore/v4.2.0/MuseScore-4.2.0.233521125-x86_64.AppImage'
-     )
-     if not os.path.exists(extra_dir):
-         subprocess.run(['chmod', '+x', f'./{apkname}'])
-         subprocess.run([f'./{apkname}', '--appimage-extract'])
-
-     mscore = f'./{extra_dir}/AppRun'
-     os.environ['QT_QPA_PLATFORM'] = 'offscreen'
-
- else:
-     mscore = "D:/Program Files/MuseScore 3/bin/MuseScore3.exe"
-
-
- def abc_to_midi(abc_content, output_midi_path):
-     score = converter.parse(abc_content, format='abc')
-     score.write('midi', fp=output_midi_path)
-     return output_midi_path
-
-
- def abc_to_musicxml(abc_content, output_xml_path):
-     score = converter.parse(abc_content, format='abc')
-     score.write('musicxml', fp=output_xml_path)
-     return output_xml_path
-
-
- def musicxml_to_mxl(xml_path):
-     mxl_file = xml_path.replace('.musicxml', '.mxl')
-     command = [mscore, "-o", mxl_file, xml_path]
-     result = subprocess.run(command)
-     print(result)
-     return mxl_file
-
-
- def midi2wav(mid_file: str):
-     wav_file = mid_file.replace('.mid', '.wav')
-     command = [mscore, "-o", wav_file, mid_file]
-     result = subprocess.run(command)
-     print(result)
-     return wav_file
-
-
- def pdf2img(pdf_path: str):
-     output_path = pdf_path.replace('.pdf', '.jpg')
-     doc = fitz.open(pdf_path)
-     # Build a list of page images
-     images = []
-     for page_number in range(doc.page_count):
-         page = doc[page_number]
-         # Render the page to an image
-         image = page.get_pixmap()
-         # Append it to the list
-         images.append(
-             Image.frombytes(
-                 "RGB",
-                 [image.width, image.height],
-                 image.samples
-             )
-         )
-     # Merge the images vertically
-     merged_image = Image.new(
-         "RGB",
-         (images[0].width, sum(image.height for image in images))
-     )
-     y_offset = 0
-     for image in images:
-         merged_image.paste(image, (0, y_offset))
-         y_offset += image.height
-     # Save the merged image as a JPG
-     merged_image.save(output_path, "JPEG")
-     # Close the PDF document
-     doc.close()
-     return output_path
-
-
- def mxl2jpg(mxl_file: str):
-     pdf_score = mxl_file.replace('.mxl', '.pdf')
-     command = [mscore, "-o", pdf_score, mxl_file]
-     result = subprocess.run(command)
-     print(result)
-     return pdf_score, pdf2img(pdf_score)
+ import os
+ import sys
+ import fitz
+ import subprocess
+ from PIL import Image
+ from music21 import converter
+ from utils import download
+
+ if sys.platform.startswith("linux"):
+     apkname = "MuseScore.AppImage"
+     extra_dir = "squashfs-root"
+     download(
+         filename=apkname,
+         url="https://master.dl.sourceforge.net/project/musescore.mirror/v4.2.0/MuseScore-4.2.0.233521125-x86_64.AppImage?viasf=1",
+     )
+     if not os.path.exists(extra_dir):
+         subprocess.run(["chmod", "+x", f"./{apkname}"])
+         subprocess.run([f"./{apkname}", "--appimage-extract"])
+
+     MSCORE = f"./{extra_dir}/AppRun"
+     os.environ["QT_QPA_PLATFORM"] = "offscreen"
+
+ else:
+     MSCORE = "D:/Program Files/MuseScore 3/bin/MuseScore3.exe"
+
+
+ def abc_to_midi(abc_content, output_midi_path):
+     score = converter.parse(abc_content, format="abc")
+     score.write("midi", fp=output_midi_path)
+     return output_midi_path
+
+
+ def abc_to_musicxml(abc_content, output_xml_path):
+     score = converter.parse(abc_content, format="abc")
+     score.write("musicxml", fp=output_xml_path)
+     return output_xml_path
+
+
+ def musicxml_to_mxl(xml_path):
+     mxl_file = xml_path.replace(".musicxml", ".mxl")
+     command = [MSCORE, "-o", mxl_file, xml_path]
+     result = subprocess.run(command)
+     print(result)
+     return mxl_file
+
+
+ def midi2wav(mid_file: str):
+     wav_file = mid_file.replace(".mid", ".wav")
+     command = [MSCORE, "-o", wav_file, mid_file]
+     result = subprocess.run(command)
+     print(result)
+     return wav_file
+
+
+ def pdf2img(pdf_path: str):
+     output_path = pdf_path.replace(".pdf", ".jpg")
+     doc = fitz.open(pdf_path)
+     # Build a list of page images
+     images = []
+     for page_number in range(doc.page_count):
+         page = doc[page_number]
+         # Render the page to an image
+         image = page.get_pixmap()
+         # Append it to the list
+         images.append(
+             Image.frombytes("RGB", [image.width, image.height], image.samples)
+         )
+     # Merge the images vertically
+     merged_image = Image.new(
+         "RGB", (images[0].width, sum(image.height for image in images))
+     )
+     y_offset = 0
+     for image in images:
+         merged_image.paste(image, (0, y_offset))
+         y_offset += image.height
+     # Save the merged image as a JPG
+     merged_image.save(output_path, "JPEG")
+     # Close the PDF document
+     doc.close()
+     return output_path
+
+
+ def mxl2jpg(mxl_file: str):
+     pdf_score = mxl_file.replace(".mxl", ".pdf")
+     command = [MSCORE, "-o", pdf_score, mxl_file]
+     result = subprocess.run(command)
+     print(result)
+     return pdf_score, pdf2img(pdf_score)
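Taken together, these helpers form an ABC → MIDI/MusicXML → MXL → PDF/JPG/WAV pipeline, where every step past music21 shells out to MuseScore. A minimal sketch of chaining them by hand (assuming a working MuseScore binary at the MSCORE path resolved above; demo.abc is a hypothetical file holding a valid tune):

from convert import abc_to_midi, abc_to_musicxml, musicxml_to_mxl, mxl2jpg, midi2wav

# Hypothetical input file with a valid abc tune.
with open("demo.abc", "r", encoding="utf-8") as f:
    abc = f.read()

mid = abc_to_midi(abc, "demo.mid")           # music21: abc -> MIDI
xml = abc_to_musicxml(abc, "demo.musicxml")  # music21: abc -> MusicXML
mxl = musicxml_to_mxl(xml)                   # MuseScore: MusicXML -> MXL
pdf, jpg = mxl2jpg(mxl)                      # MuseScore + PyMuPDF: MXL -> PDF -> merged JPG
wav = midi2wav(mid)                          # MuseScore: MIDI -> WAV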
 
 
 
 
 
requirements.txt CHANGED
@@ -1,9 +1,9 @@
- transformers==4.18.0
- samplings==0.1.7
- unidecode
- music21
- autopep8
- pillow==9.4.0
- gradio
- pymupdf
+ transformers==4.18.0
+ samplings==0.1.7
+ unidecode
+ music21
+ autopep8
+ pillow==9.4.0
+ gradio==4.8.0
+ pymupdf
  torch
utils.py CHANGED
@@ -1,394 +1,377 @@
- import os
- import re
- import torch
- import random
- from config import *
- from tqdm import tqdm
- from unidecode import unidecode
- from torch.utils.data import Dataset
- from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
- from samplings import top_p_sampling, top_k_sampling, temperature_sampling
-
- if torch.cuda.is_available():
-     device = torch.device("cuda")
- else:
-     device = torch.device("cpu")
-
-
- def template(region):
-     return f'''A:{region}
- S:2
- B:9
- E:4
- B:9
- L:1/8
- M:3/4
- K:D
- de |"D"'''
-
-
- def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
-     import time
-     import requests
-
-     try:
-         response = requests.get(url, stream=True)
-         total_size = int(response.headers.get("content-length", 0))
-         chunk_size = 1024
-
-         with open(filename, "wb") as file, tqdm(
-             desc=f"Downloading weights to '{filename}'...",
-             total=total_size,
-             unit="B",
-             unit_scale=True,
-             unit_divisor=1024,
-         ) as bar:
-             for data in response.iter_content(chunk_size=chunk_size):
-                 size = file.write(data)
-                 bar.update(size)
-
-     except Exception as e:
-         print(f"Error: {e}")
-         time.sleep(3)
-         download(filename, ZH_WEIGHT_URL)
-
-
- class Patchilizer:
-     """
-     A class for converting music bars to patches and vice versa.
-     """
-
-     def __init__(self):
-         self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
-         self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
-         self.pad_token_id = 0
-         self.bos_token_id = 1
-         self.eos_token_id = 2
-
-     def split_bars(self, body):
-         """
-         Split a body of music into individual bars.
-         """
-         bars = re.split(self.regexPattern, "".join(body))
-         bars = list(filter(None, bars))  # remove empty strings
-         if bars[0] in self.delimiters:
-             bars[1] = bars[0] + bars[1]
-             bars = bars[1:]
-
-         bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
-         return bars
-
-     def bar2patch(self, bar, patch_size=PATCH_SIZE):
-         """
-         Convert a bar into a patch of specified length.
-         """
-         patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
-         patch = patch[:patch_size]
-         patch += [self.pad_token_id] * (patch_size - len(patch))
-         return patch
-
-     def patch2bar(self, patch):
-         """
-         Convert a patch into a bar.
-         """
-         return "".join(
-             chr(idx) if idx > self.eos_token_id else ""
-             for idx in patch
-             if idx != self.eos_token_id
-         )
-
-     def encode(
-         self,
-         abc_code,
-         patch_length=PATCH_LENGTH,
-         patch_size=PATCH_SIZE,
-         add_special_patches=False,
-     ):
-         """
-         Encode music into patches of specified length.
-         """
-         lines = unidecode(abc_code).split("\n")
-         lines = list(filter(None, lines))  # remove empty lines
-
-         body = ""
-         patches = []
-
-         for line in lines:
-             if len(line) > 1 and (
-                 (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
-             ):
-                 if body:
-                     bars = self.split_bars(body)
-                     patches.extend(
-                         self.bar2patch(
-                             bar + "\n" if idx == len(bars) - 1 else bar, patch_size
-                         )
-                         for idx, bar in enumerate(bars)
-                     )
-                     body = ""
-
-                 patches.append(self.bar2patch(line + "\n", patch_size))
-
-             else:
-                 body += line + "\n"
-
-         if body:
-             patches.extend(
-                 self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
-             )
-
-         if add_special_patches:
-             bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
-             eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
-             patches = [bos_patch] + patches + [eos_patch]
-
-         return patches[:patch_length]
-
-     def decode(self, patches):
-         """
-         Decode patches into music.
-         """
-         return "".join(self.patch2bar(patch) for patch in patches)
-
-
- class PatchLevelDecoder(PreTrainedModel):
-     """
-     A Patch-level Decoder model for generating patch features in an auto-regressive manner.
-     It inherits PreTrainedModel from transformers.
-     """
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
-         torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
-         self.base = GPT2Model(config)
-
-     def forward(self, patches: torch.Tensor) -> torch.Tensor:
-         """
-         The forward pass of the patch-level decoder model.
-         :param patches: the patches to be encoded
-         :return: the encoded patches
-         """
-         patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
-         patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
-         patches = self.patch_embedding(patches.to(self.device))
-
-         return self.base(inputs_embeds=patches)
-
-
- class CharLevelDecoder(PreTrainedModel):
-     """
-     A Char-level Decoder model for generating the characters within each bar patch sequentially.
-     It inherits PreTrainedModel from transformers.
-     """
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.pad_token_id = 0
-         self.bos_token_id = 1
-         self.eos_token_id = 2
-         self.base = GPT2LMHeadModel(config)
-
-     def forward(
-         self,
-         encoded_patches: torch.Tensor,
-         target_patches: torch.Tensor,
-         patch_sampling_batch_size: int,
-     ):
-         """
-         The forward pass of the char-level decoder model.
-         :param encoded_patches: the encoded patches
-         :param target_patches: the target patches
-         :return: the decoded patches
-         """
-         # preparing the labels for model training
-         target_masks = target_patches == self.pad_token_id
-         labels = target_patches.clone().masked_fill_(target_masks, -100)
-
-         # masking the labels for model training
-         target_masks = torch.ones_like(labels)
-         target_masks = target_masks.masked_fill_(labels == -100, 0)
-
-         # select patches
-         if (
-             patch_sampling_batch_size != 0
-             and patch_sampling_batch_size < target_patches.shape[0]
-         ):
-             indices = list(range(len(target_patches)))
-             random.shuffle(indices)
-             selected_indices = sorted(indices[:patch_sampling_batch_size])
-
-             target_patches = target_patches[selected_indices, :]
-             target_masks = target_masks[selected_indices, :]
-             encoded_patches = encoded_patches[selected_indices, :]
-             labels = labels[selected_indices, :]
-
-         # get input embeddings
-         inputs_embeds = torch.nn.functional.embedding(
-             target_patches, self.base.transformer.wte.weight
-         )
-
-         # concatenate the encoded patches with the input embeddings
-         inputs_embeds = torch.cat(
-             (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
-         )
-
-         return self.base(
-             inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
-         )
-
-     def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
-         """
-         The generate function for generating a patch based on the encoded patch and already generated tokens.
-         :param encoded_patch: the encoded patch
-         :param tokens: already generated tokens in the patch
-         :return: the probability distribution of next token
-         """
-         encoded_patch = encoded_patch.reshape(1, 1, -1)
-         tokens = tokens.reshape(1, -1)
-
-         # Get input embeddings
-         tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
-
-         # Concatenate the encoded patch with the input embeddings
-         tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
-
-         # Get output from model
-         outputs = self.base(inputs_embeds=tokens)
-
-         # Get probabilities of next token
-         probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
-
-         return probs
-
-
- class TunesFormer(PreTrainedModel):
-     """
-     TunesFormer is a hierarchical music generation model based on bar patching.
-     It includes a patch-level decoder and a character-level decoder.
-     It inherits PreTrainedModel from transformers.
-     """
-
-     def __init__(self, encoder_config, decoder_config, share_weights=False):
-         super().__init__(encoder_config)
-         self.pad_token_id = 0
-         self.bos_token_id = 1
-         self.eos_token_id = 2
-         if share_weights:
-             max_layers = max(
-                 encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
-             )
-
-             max_context_size = max(encoder_config.max_length, decoder_config.max_length)
-
-             max_position_embeddings = max(
-                 encoder_config.max_position_embeddings,
-                 decoder_config.max_position_embeddings,
-             )
-
-             encoder_config.num_hidden_layers = max_layers
-             encoder_config.max_length = max_context_size
-             encoder_config.max_position_embeddings = max_position_embeddings
-             decoder_config.num_hidden_layers = max_layers
-             decoder_config.max_length = max_context_size
-             decoder_config.max_position_embeddings = max_position_embeddings
-
-         self.patch_level_decoder = PatchLevelDecoder(encoder_config)
-         self.char_level_decoder = CharLevelDecoder(decoder_config)
-
-         if share_weights:
-             self.patch_level_decoder.base = self.char_level_decoder.base.transformer
-
-     def forward(
-         self,
-         patches: torch.Tensor,
-         patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
-     ):
-         """
-         The forward pass of the TunesFormer model.
-         :param patches: the patches to be both encoded and decoded
-         :return: the decoded patches
-         """
-         patches = patches.reshape(len(patches), -1, PATCH_SIZE)
-         encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
-
-         return self.char_level_decoder(
-             encoded_patches.squeeze(0)[:-1, :],
-             patches.squeeze(0)[1:, :],
-             patch_sampling_batch_size,
-         )
-
-     def generate(
-         self,
-         patches: torch.Tensor,
-         tokens: torch.Tensor,
-         top_p: float = 1,
-         top_k: int = 0,
-         temperature: float = 1,
-         seed: int = None,
-     ):
-         """
-         The generate function for generating patches based on patches.
-         :param patches: the patches to be encoded
-         :return: the generated patches
-         """
-         patches = patches.reshape(len(patches), -1, PATCH_SIZE)
-         encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
-
-         if tokens == None:
-             tokens = torch.tensor([self.bos_token_id], device=self.device)
-
-         generated_patch = []
-         random.seed(seed)
-
-         while True:
-             if seed != None:
-                 n_seed = random.randint(0, 1000000)
-                 random.seed(n_seed)
-
-             else:
-                 n_seed = None
-
-             prob = (
-                 self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
-                 .cpu()
-                 .detach()
-                 .numpy()
-             )
-
-             prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
-             prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
-
-             token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
-
-             generated_patch.append(token)
-             if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
-                 break
-
-             else:
-                 tokens = torch.cat(
-                     (tokens, torch.tensor([token], device=self.device)), dim=0
-                 )
-
-         return generated_patch, n_seed
-
-
- class PatchilizedData(Dataset):
-     def __init__(self, items, patchilizer):
-         self.texts = []
-
-         for item in tqdm(items):
-             text = item["control code"] + "\n".join(
-                 item["abc notation"].split("\n")[1:]
-             )
-             input_patch = patchilizer.encode(text, add_special_patches=True)
-             input_patch = torch.tensor(input_patch)
-             if torch.sum(input_patch) != 0:
-                 self.texts.append(input_patch)
-
-     def __len__(self):
-         return len(self.texts)
-
-     def __getitem__(self, idx):
-         return self.texts[idx]
+ import re
+ import time
+ import torch
+ import random
+ import requests
+ from config import *
+ from tqdm import tqdm
+ from unidecode import unidecode
+ from torch.utils.data import Dataset
+ from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
+
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+ def download(filename: str, url: str):
+     try:
+         response = requests.get(url, stream=True)
+         total_size = int(response.headers.get("content-length", 0))
+         chunk_size = 1024
+
+         with open(filename, "wb") as file, tqdm(
+             desc=f"Downloading {filename} from '{url}'...",
+             total=total_size,
+             unit="B",
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as bar:
+             for data in response.iter_content(chunk_size=chunk_size):
+                 size = file.write(data)
+                 bar.update(size)
+
+     except Exception as e:
+         print(f"Error: {e}, retrying...")
+         time.sleep(10)
+         download(filename, url)
+
+
+ class Patchilizer:
+     """
+     A class for converting music bars to patches and vice versa.
+     """
+
+     def __init__(self):
+         self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
+         self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+
+     def split_bars(self, body):
+         """
+         Split a body of music into individual bars.
+         """
+         bars = re.split(self.regexPattern, "".join(body))
+         bars = list(filter(None, bars))  # remove empty strings
+         if bars[0] in self.delimiters:
+             bars[1] = bars[0] + bars[1]
+             bars = bars[1:]
+
+         bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
+         return bars
+
+     def bar2patch(self, bar, patch_size=PATCH_SIZE):
+         """
+         Convert a bar into a patch of specified length.
+         """
+         patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
+         patch = patch[:patch_size]
+         patch += [self.pad_token_id] * (patch_size - len(patch))
+         return patch
+
+     def patch2bar(self, patch):
+         """
+         Convert a patch into a bar.
+         """
+         return "".join(
+             chr(idx) if idx > self.eos_token_id else ""
+             for idx in patch
+             if idx != self.eos_token_id
+         )
+
+     def encode(
+         self,
+         abc_code,
+         patch_length=PATCH_LENGTH,
+         patch_size=PATCH_SIZE,
+         add_special_patches=False,
+     ):
+         """
+         Encode music into patches of specified length.
+         """
+         lines = unidecode(abc_code).split("\n")
+         lines = list(filter(None, lines))  # remove empty lines
+
+         body = ""
+         patches = []
+
+         for line in lines:
+             if len(line) > 1 and (
+                 (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
+             ):
+                 if body:
+                     bars = self.split_bars(body)
+                     patches.extend(
+                         self.bar2patch(
+                             bar + "\n" if idx == len(bars) - 1 else bar, patch_size
+                         )
+                         for idx, bar in enumerate(bars)
+                     )
+                     body = ""
+
+                 patches.append(self.bar2patch(line + "\n", patch_size))
+
+             else:
+                 body += line + "\n"
+
+         if body:
+             patches.extend(
+                 self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
+             )
+
+         if add_special_patches:
+             bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
+             eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
+             patches = [bos_patch] + patches + [eos_patch]
+
+         return patches[:patch_length]
+
+     def decode(self, patches):
+         """
+         Decode patches into music.
+         """
+         return "".join(self.patch2bar(patch) for patch in patches)
+
+
+ class PatchLevelDecoder(PreTrainedModel):
+     """
+     A Patch-level Decoder model for generating patch features in an auto-regressive manner.
+     It inherits PreTrainedModel from transformers.
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
+         torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
+         self.base = GPT2Model(config)
+
+     def forward(self, patches: torch.Tensor) -> torch.Tensor:
+         """
+         The forward pass of the patch-level decoder model.
+         :param patches: the patches to be encoded
+         :return: the encoded patches
+         """
+         patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
+         patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
+         patches = self.patch_embedding(patches.to(self.device))
+
+         return self.base(inputs_embeds=patches)
+
+
+ class CharLevelDecoder(PreTrainedModel):
+     """
+     A Char-level Decoder model for generating the characters within each bar patch sequentially.
+     It inherits PreTrainedModel from transformers.
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.base = GPT2LMHeadModel(config)
+
+     def forward(
+         self,
+         encoded_patches: torch.Tensor,
+         target_patches: torch.Tensor,
+         patch_sampling_batch_size: int,
+     ):
+         """
+         The forward pass of the char-level decoder model.
+         :param encoded_patches: the encoded patches
+         :param target_patches: the target patches
+         :return: the decoded patches
+         """
+         # preparing the labels for model training
+         target_masks = target_patches == self.pad_token_id
+         labels = target_patches.clone().masked_fill_(target_masks, -100)
+
+         # masking the labels for model training
+         target_masks = torch.ones_like(labels)
+         target_masks = target_masks.masked_fill_(labels == -100, 0)
+
+         # select patches
+         if (
+             patch_sampling_batch_size != 0
+             and patch_sampling_batch_size < target_patches.shape[0]
+         ):
+             indices = list(range(len(target_patches)))
+             random.shuffle(indices)
+             selected_indices = sorted(indices[:patch_sampling_batch_size])
+
+             target_patches = target_patches[selected_indices, :]
+             target_masks = target_masks[selected_indices, :]
+             encoded_patches = encoded_patches[selected_indices, :]
+             labels = labels[selected_indices, :]
+
+         # get input embeddings
+         inputs_embeds = torch.nn.functional.embedding(
+             target_patches, self.base.transformer.wte.weight
+         )
+
+         # concatenate the encoded patches with the input embeddings
+         inputs_embeds = torch.cat(
+             (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
+         )
+
+         return self.base(
+             inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
+         )
+
+     def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
+         """
+         The generate function for generating a patch based on the encoded patch and already generated tokens.
+         :param encoded_patch: the encoded patch
+         :param tokens: already generated tokens in the patch
+         :return: the probability distribution of next token
+         """
+         encoded_patch = encoded_patch.reshape(1, 1, -1)
+         tokens = tokens.reshape(1, -1)
+
+         # Get input embeddings
+         tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
+
+         # Concatenate the encoded patch with the input embeddings
+         tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
+
+         # Get output from model
+         outputs = self.base(inputs_embeds=tokens)
+
+         # Get probabilities of next token
+         probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
+
+         return probs
+
+
+ class TunesFormer(PreTrainedModel):
+     """
+     TunesFormer is a hierarchical music generation model based on bar patching.
+     It includes a patch-level decoder and a character-level decoder.
+     It inherits PreTrainedModel from transformers.
+     """
+
+     def __init__(self, encoder_config, decoder_config, share_weights=False):
+         super().__init__(encoder_config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         if share_weights:
+             max_layers = max(
+                 encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
+             )
+
+             max_context_size = max(encoder_config.max_length, decoder_config.max_length)
+
+             max_position_embeddings = max(
+                 encoder_config.max_position_embeddings,
+                 decoder_config.max_position_embeddings,
+             )
+
+             encoder_config.num_hidden_layers = max_layers
+             encoder_config.max_length = max_context_size
+             encoder_config.max_position_embeddings = max_position_embeddings
+             decoder_config.num_hidden_layers = max_layers
+             decoder_config.max_length = max_context_size
+             decoder_config.max_position_embeddings = max_position_embeddings
+
+         self.patch_level_decoder = PatchLevelDecoder(encoder_config)
+         self.char_level_decoder = CharLevelDecoder(decoder_config)
+
+         if share_weights:
+             self.patch_level_decoder.base = self.char_level_decoder.base.transformer
+
+     def forward(
+         self,
+         patches: torch.Tensor,
+         patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
+     ):
+         """
+         The forward pass of the TunesFormer model.
+         :param patches: the patches to be both encoded and decoded
+         :return: the decoded patches
+         """
+         patches = patches.reshape(len(patches), -1, PATCH_SIZE)
+         encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
+
+         return self.char_level_decoder(
+             encoded_patches.squeeze(0)[:-1, :],
+             patches.squeeze(0)[1:, :],
+             patch_sampling_batch_size,
+         )
+
+     def generate(
+         self,
+         patches: torch.Tensor,
+         tokens: torch.Tensor,
+         top_p: float = 1,
+         top_k: int = 0,
+         temperature: float = 1,
+         seed: int = None,
+     ):
+         """
+         The generate function for generating patches based on patches.
+         :param patches: the patches to be encoded
+         :return: the generated patches
+         """
+         patches = patches.reshape(len(patches), -1, PATCH_SIZE)
+         encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
+
+         if tokens == None:
+             tokens = torch.tensor([self.bos_token_id], device=self.device)
+
+         generated_patch = []
+         random.seed(seed)
+
+         while True:
+             if seed != None:
+                 n_seed = random.randint(0, 1000000)
+                 random.seed(n_seed)
+
+             else:
+                 n_seed = None
+
+             prob = (
+                 self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
+                 .cpu()
+                 .detach()
+                 .numpy()
+             )
+
+             prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
+             prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
+
+             token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
+
+             generated_patch.append(token)
+             if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
+                 break
+
+             else:
+                 tokens = torch.cat(
+                     (tokens, torch.tensor([token], device=self.device)), dim=0
+                 )
+
+         return generated_patch, n_seed
+
+
+ class PatchilizedData(Dataset):
+     def __init__(self, items, patchilizer):
+         self.texts = []
+
+         for item in tqdm(items):
+             text = item["control code"] + "\n".join(
+                 item["abc notation"].split("\n")[1:]
+             )
+             input_patch = patchilizer.encode(text, add_special_patches=True)
+             input_patch = torch.tensor(input_patch)
+             if torch.sum(input_patch) != 0:
+                 self.texts.append(input_patch)
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         return self.texts[idx]
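As a sanity check on the bar-patching scheme, Patchilizer should round-trip a short tune: encode() splits header lines and bars into fixed-size character patches, and decode() reassembles them. A minimal sketch (assuming utils.py as committed above; the abc snippet is an illustrative example):

from utils import Patchilizer

patchilizer = Patchilizer()
abc = 'L:1/8\nM:3/4\nK:D\nde |"D" d2 de |"A" c2 cd |\n'
patches = patchilizer.encode(abc, add_special_patches=True)
print(len(patches))                 # patch count, capped at PATCH_LENGTH (128)
print(patchilizer.decode(patches))  # reproduces the headers and bars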