admin commited on
Commit
f405cb8
1 Parent(s): c84fad0
Files changed (8) hide show
  1. .gitattributes +12 -11
  2. .gitignore +8 -0
  3. README.md +4 -4
  4. app.py +243 -0
  5. convert.py +61 -0
  6. model.py +325 -0
  7. requirements.txt +9 -0
  8. utils.py +70 -0
.gitattributes CHANGED
@@ -1,35 +1,36 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
35
+ *.sf3 filter=lfs diff=lfs merge=lfs -text
36
+ *.AppImage filter=lfs diff=lfs merge=lfs -textlibnss3.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/*
2
+ output/*
3
+ rename.sh
4
+ test.py
5
+ gpt2-abcmusic/*
6
+ *.pth
7
+ flagged/*
8
+ mscore3/*
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: HoyoGPT
3
- emoji: 🐢
4
- colorFrom: pink
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
  title: HoyoGPT
3
+ emoji: 🎹
4
+ colorFrom: green
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.39.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import time
4
+ import torch
5
+ import shutil
6
+ import argparse
7
+ import warnings
8
+ import gradio as gr
9
+ from transformers import GPT2Config
10
+ from model import Patchilizer, TunesFormer
11
+ from convert import abc2xml, xml2, xml2img
12
+ from utils import (
13
+ PATCH_NUM_LAYERS,
14
+ PATCH_LENGTH,
15
+ CHAR_NUM_LAYERS,
16
+ PATCH_SIZE,
17
+ SHARE_WEIGHTS,
18
+ WEIGHTS_PATH,
19
+ TEMP_DIR,
20
+ TEYVAT,
21
+ DEVICE,
22
+ )
23
+
24
+
25
+ def get_args(parser: argparse.ArgumentParser):
26
+ parser.add_argument(
27
+ "-num_tunes",
28
+ type=int,
29
+ default=1,
30
+ help="the number of independently computed returned tunes",
31
+ )
32
+ parser.add_argument(
33
+ "-max_patch",
34
+ type=int,
35
+ default=128,
36
+ help="integer to define the maximum length in tokens of each tune",
37
+ )
38
+ parser.add_argument(
39
+ "-top_p",
40
+ type=float,
41
+ default=0.8,
42
+ help="float to define the tokens that are within the sample operation of text generation",
43
+ )
44
+ parser.add_argument(
45
+ "-top_k",
46
+ type=int,
47
+ default=8,
48
+ help="integer to define the tokens that are within the sample operation of text generation",
49
+ )
50
+ parser.add_argument(
51
+ "-temperature",
52
+ type=float,
53
+ default=1.2,
54
+ help="the temperature of the sampling operation",
55
+ )
56
+ parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
57
+ parser.add_argument(
58
+ "-show_control_code",
59
+ type=bool,
60
+ default=False,
61
+ help="whether to show control code",
62
+ )
63
+ return parser.parse_args()
64
+
65
+
66
+ def generate_music(args, region: str):
67
+ patchilizer = Patchilizer()
68
+ patch_config = GPT2Config(
69
+ num_hidden_layers=PATCH_NUM_LAYERS,
70
+ max_length=PATCH_LENGTH,
71
+ max_position_embeddings=PATCH_LENGTH,
72
+ vocab_size=1,
73
+ )
74
+ char_config = GPT2Config(
75
+ num_hidden_layers=CHAR_NUM_LAYERS,
76
+ max_length=PATCH_SIZE,
77
+ max_position_embeddings=PATCH_SIZE,
78
+ vocab_size=128,
79
+ )
80
+ model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
81
+ checkpoint = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
82
+ model.load_state_dict(checkpoint["model"])
83
+ model = model.to(DEVICE)
84
+ model.eval()
85
+ prompt = f"A:{region}\n"
86
+ tunes = ""
87
+ num_tunes = args.num_tunes
88
+ max_patch = args.max_patch
89
+ top_p = args.top_p
90
+ top_k = args.top_k
91
+ temperature = args.temperature
92
+ seed = args.seed
93
+ show_control_code = args.show_control_code
94
+ print(" Hyper parms ".center(60, "#"), "\n")
95
+ arg_dict: dict = vars(args)
96
+ for key in arg_dict.keys():
97
+ print(f"{key}: {str(arg_dict[key])}")
98
+
99
+ print("\n", " Output tunes ".center(60, "#"))
100
+ start_time = time.time()
101
+ for i in range(num_tunes):
102
+ title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
103
+ tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
104
+ lines = re.split(r"(\n)", tune)
105
+ tune = ""
106
+ skip = False
107
+ for line in lines:
108
+ if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
109
+ if not skip:
110
+ print(line, end="")
111
+ tune += line
112
+
113
+ skip = False
114
+
115
+ else:
116
+ skip = True
117
+
118
+ input_patches = torch.tensor(
119
+ [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
120
+ )
121
+
122
+ if tune == "":
123
+ tokens = None
124
+
125
+ else:
126
+ prefix = patchilizer.decode(input_patches[0])
127
+ remaining_tokens = prompt[len(prefix) :]
128
+ tokens = torch.tensor(
129
+ [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
130
+ device=DEVICE,
131
+ )
132
+
133
+ while input_patches.shape[1] < max_patch:
134
+ predicted_patch, seed = model.generate(
135
+ input_patches,
136
+ tokens,
137
+ top_p=top_p,
138
+ top_k=top_k,
139
+ temperature=temperature,
140
+ seed=seed,
141
+ )
142
+ tokens = None
143
+ if predicted_patch[0] != patchilizer.eos_token_id:
144
+ next_bar = patchilizer.decode([predicted_patch])
145
+ if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
146
+ print(next_bar, end="")
147
+ tune += next_bar
148
+
149
+ if next_bar == "":
150
+ break
151
+
152
+ next_bar = remaining_tokens + next_bar
153
+ remaining_tokens = ""
154
+ predicted_patch = torch.tensor(
155
+ patchilizer.bar2patch(next_bar), device=DEVICE
156
+ ).unsqueeze(0)
157
+ input_patches = torch.cat(
158
+ [input_patches, predicted_patch.unsqueeze(0)], dim=1
159
+ )
160
+
161
+ else:
162
+ break
163
+
164
+ tunes += f"{tune}\n\n"
165
+ print("\n")
166
+
167
+ print("Generation time: {:.2f} seconds".format(time.time() - start_time))
168
+ timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
169
+ try:
170
+ xml = abc2xml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
171
+ midi = xml2(xml, "mid")
172
+ audio = xml2(xml, "wav")
173
+ pdf, jpg = xml2img(xml)
174
+ mxl = xml2(xml, "mxl")
175
+ return tunes, midi, pdf, xml, mxl, jpg, audio
176
+
177
+ except Exception as e:
178
+ print(f"Invalid abc generated: {e}, retrying...")
179
+ return generate_music(args, region)
180
+
181
+
182
+ def infer(region: str):
183
+ if os.path.exists(TEMP_DIR):
184
+ shutil.rmtree(TEMP_DIR)
185
+
186
+ os.makedirs(TEMP_DIR, exist_ok=True)
187
+ parser = argparse.ArgumentParser()
188
+ args = get_args(parser)
189
+ return generate_music(args, TEYVAT[region])
190
+
191
+
192
+ if __name__ == "__main__":
193
+ warnings.filterwarnings("ignore")
194
+ with gr.Blocks() as demo:
195
+ gr.Markdown(
196
+ """
197
+ <center>欢迎使用此创空间, 此创空间由bilibili <a href="https://space.bilibili.com/30620472">@亦真亦幻Studio</a> 基于 Tunesformer 开源项目制作,完全免费。</center>
198
+ <center>Welcome to this space made by bilibili <a href="https://space.bilibili.com/30620472">@MuGeminorum</a> based on the Tunesformer open source project, which is totally free!</center>"""
199
+ )
200
+ with gr.Row():
201
+ with gr.Column():
202
+ region_opt = gr.Dropdown(
203
+ choices=list(TEYVAT.keys()),
204
+ value="蒙德 Mondstadt",
205
+ label="地区风格 Region",
206
+ )
207
+ gen_btn = gr.Button("生成 Generate")
208
+ gr.Markdown(
209
+ """
210
+ <center>
211
+ 当前模型还在调试中,计划在原神主线杀青后,所有国家地区角色全部开放后,二创音乐会齐全且样本均衡,届时重新微调模型并添加现实风格筛选辅助游戏各国家输出强化学习,以提升输出区分度与质量。<br>The current model is still in debugging, the plan is in the Genshin Impact after the main line is killed, all countries and regions after all the characters are open, the second creation of the concert will be complete and the sample is balanced, at that time to re-fine-tune the model and add the reality of the style of screening to assist in the game of each country's output to strengthen the learning in order to enhance the output differentiation and quality.
212
+
213
+ 数据来源 (Data source): <a href="https://musescore.org">MuseScore</a><br>
214
+ Tag 嵌入数据来源 (Tags source): <a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
215
+ 模型基础 (Model base): <a href="https://github.com/sander-wood/tunesformer">Tunesformer</a>
216
+
217
+ 注:崩铁方面数据工程正在运作中,未来也希望随主线杀青而基线化。<br>Note: Data engineering on the Star Rail is in operation, and will hopefully be baselined in the future as well with the mainline kill.</center>"""
218
+ )
219
+
220
+ with gr.Column():
221
+ wav_output = gr.Audio(label="音频 (Audio)", type="filepath")
222
+ dld_midi = gr.components.File(label="下载 MIDI (Download MIDI)")
223
+ pdf_score = gr.components.File(label="下载 PDF 乐谱 (Download PDF)")
224
+ dld_xml = gr.components.File(label="下载 MusicXML (Download MusicXML)")
225
+ dld_mxl = gr.components.File(label="下载 MXL (Download MXL)")
226
+ abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
227
+ img_score = gr.Image(label="五线谱 (Staff)", type="filepath")
228
+
229
+ gen_btn.click(
230
+ infer,
231
+ inputs=region_opt,
232
+ outputs=[
233
+ abc_output,
234
+ dld_midi,
235
+ pdf_score,
236
+ dld_xml,
237
+ dld_mxl,
238
+ img_score,
239
+ wav_output,
240
+ ],
241
+ )
242
+
243
+ demo.launch()
convert.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fitz
3
+ import subprocess
4
+ from PIL import Image
5
+ from music21 import converter
6
+ from utils import MSCORE
7
+
8
+
9
+ def abc2xml(abc_content, output_xml_path):
10
+ score = converter.parse(abc_content, format="abc")
11
+ score.write("musicxml", fp=output_xml_path, encoding="utf-8")
12
+ return output_xml_path
13
+
14
+
15
+ def xml2(xml_path: str, target_fmt: str):
16
+ src_fmt = os.path.basename(xml_path).split(".")[-1]
17
+ if not "." in target_fmt:
18
+ target_fmt = "." + target_fmt
19
+
20
+ target_file = xml_path.replace(f".{src_fmt}", target_fmt)
21
+ command = [MSCORE, "-o", target_file, xml_path]
22
+ result = subprocess.run(command)
23
+ print(result)
24
+ return target_file
25
+
26
+
27
+ def pdf2img(pdf_path: str):
28
+ output_path = pdf_path.replace(".pdf", ".jpg")
29
+ doc = fitz.open(pdf_path)
30
+ # 创建一个图像列表
31
+ images = []
32
+ for page_number in range(doc.page_count):
33
+ page = doc[page_number]
34
+ # 将页面渲染为图像
35
+ image = page.get_pixmap()
36
+ # 将图像添加到列表
37
+ images.append(
38
+ Image.frombytes("RGB", [image.width, image.height], image.samples)
39
+ )
40
+ # 竖向合并图像
41
+ merged_image = Image.new(
42
+ "RGB", (images[0].width, sum(image.height for image in images))
43
+ )
44
+ y_offset = 0
45
+ for image in images:
46
+ merged_image.paste(image, (0, y_offset))
47
+ y_offset += image.height
48
+ # 保存合并后的图像为JPG
49
+ merged_image.save(output_path, "JPEG")
50
+ # 关闭PDF文档
51
+ doc.close()
52
+ return output_path
53
+
54
+
55
+ def xml2img(xml_file: str):
56
+ ext = os.path.basename(xml_file).split(".")[-1]
57
+ pdf_score = xml_file.replace(f".{ext}", ".pdf")
58
+ command = [MSCORE, "-o", pdf_score, xml_file]
59
+ result = subprocess.run(command)
60
+ print(result)
61
+ return pdf_score, pdf2img(pdf_score)
model.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import random
4
+ from tqdm import tqdm
5
+ from unidecode import unidecode
6
+ from torch.utils.data import Dataset
7
+ from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
8
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
9
+ from utils import PATCH_SIZE, PATCH_LENGTH, PATCH_SAMPLING_BATCH_SIZE
10
+
11
+
12
+ class Patchilizer:
13
+ """
14
+ A class for converting music bars to patches and vice versa.
15
+ """
16
+
17
+ def __init__(self):
18
+ self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
19
+ self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
20
+ self.pad_token_id = 0
21
+ self.bos_token_id = 1
22
+ self.eos_token_id = 2
23
+
24
+ def split_bars(self, body):
25
+ """
26
+ Split a body of music into individual bars.
27
+ """
28
+ bars = re.split(self.regexPattern, "".join(body))
29
+ bars = list(filter(None, bars))
30
+ # remove empty strings
31
+ if bars[0] in self.delimiters:
32
+ bars[1] = bars[0] + bars[1]
33
+ bars = bars[1:]
34
+
35
+ bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
36
+ return bars
37
+
38
+ def bar2patch(self, bar, patch_size=PATCH_SIZE):
39
+ """
40
+ Convert a bar into a patch of specified length.
41
+ """
42
+ patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
43
+ patch = patch[:patch_size]
44
+ patch += [self.pad_token_id] * (patch_size - len(patch))
45
+ return patch
46
+
47
+ def patch2bar(self, patch):
48
+ """
49
+ Convert a patch into a bar.
50
+ """
51
+ return "".join(
52
+ chr(idx) if idx > self.eos_token_id else ""
53
+ for idx in patch
54
+ if idx != self.eos_token_id
55
+ )
56
+
57
+ def encode(
58
+ self,
59
+ abc_code,
60
+ patch_length=PATCH_LENGTH,
61
+ patch_size=PATCH_SIZE,
62
+ add_special_patches=False,
63
+ ):
64
+ """
65
+ Encode music into patches of specified length.
66
+ """
67
+ lines = unidecode(abc_code).split("\n")
68
+ lines = list(filter(None, lines)) # remove empty lines
69
+ body = ""
70
+ patches = []
71
+ for line in lines:
72
+ if len(line) > 1 and (
73
+ (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
74
+ ):
75
+ if body:
76
+ bars = self.split_bars(body)
77
+ patches.extend(
78
+ self.bar2patch(
79
+ bar + "\n" if idx == len(bars) - 1 else bar, patch_size
80
+ )
81
+ for idx, bar in enumerate(bars)
82
+ )
83
+ body = ""
84
+
85
+ patches.append(self.bar2patch(line + "\n", patch_size))
86
+
87
+ else:
88
+ body += line + "\n"
89
+
90
+ if body:
91
+ patches.extend(
92
+ self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
93
+ )
94
+
95
+ if add_special_patches:
96
+ bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
97
+ eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
98
+ patches = [bos_patch] + patches + [eos_patch]
99
+
100
+ return patches[:patch_length]
101
+
102
+ def decode(self, patches):
103
+ """
104
+ Decode patches into music.
105
+ """
106
+ return "".join(self.patch2bar(patch) for patch in patches)
107
+
108
+
109
+ class PatchLevelDecoder(PreTrainedModel):
110
+ """
111
+ An Patch-level Decoder model for generating patch features in an auto-regressive manner.
112
+ It inherits PreTrainedModel from transformers.
113
+ """
114
+
115
+ def __init__(self, config):
116
+ super().__init__(config)
117
+ self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
118
+ torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
119
+ self.base = GPT2Model(config)
120
+
121
+ def forward(self, patches: torch.Tensor) -> torch.Tensor:
122
+ """
123
+ The forward pass of the patch-level decoder model.
124
+ :param patches: the patches to be encoded
125
+ :return: the encoded patches
126
+ """
127
+ patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
128
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
129
+ patches = self.patch_embedding(patches.to(self.device))
130
+ return self.base(inputs_embeds=patches)
131
+
132
+
133
+ class CharLevelDecoder(PreTrainedModel):
134
+ """
135
+ A Char-level Decoder model for generating the characters within each bar patch sequentially.
136
+ It inherits PreTrainedModel from transformers.
137
+ """
138
+
139
+ def __init__(self, config):
140
+ super().__init__(config)
141
+ self.pad_token_id = 0
142
+ self.bos_token_id = 1
143
+ self.eos_token_id = 2
144
+ self.base = GPT2LMHeadModel(config)
145
+
146
+ def forward(
147
+ self,
148
+ encoded_patches: torch.Tensor,
149
+ target_patches: torch.Tensor,
150
+ patch_sampling_batch_size: int,
151
+ ):
152
+ """
153
+ The forward pass of the char-level decoder model.
154
+ :param encoded_patches: the encoded patches
155
+ :param target_patches: the target patches
156
+ :return: the decoded patches
157
+ """
158
+ # preparing the labels for model training
159
+ target_masks = target_patches == self.pad_token_id
160
+ labels = target_patches.clone().masked_fill_(target_masks, -100)
161
+ # masking the labels for model training
162
+ target_masks = torch.ones_like(labels)
163
+ target_masks = target_masks.masked_fill_(labels == -100, 0)
164
+ # select patches
165
+ if (
166
+ patch_sampling_batch_size != 0
167
+ and patch_sampling_batch_size < target_patches.shape[0]
168
+ ):
169
+ indices = list(range(len(target_patches)))
170
+ random.shuffle(indices)
171
+ selected_indices = sorted(indices[:patch_sampling_batch_size])
172
+ target_patches = target_patches[selected_indices, :]
173
+ target_masks = target_masks[selected_indices, :]
174
+ encoded_patches = encoded_patches[selected_indices, :]
175
+ labels = labels[selected_indices, :]
176
+
177
+ # get input embeddings
178
+ inputs_embeds = torch.nn.functional.embedding(
179
+ target_patches, self.base.transformer.wte.weight
180
+ )
181
+ # concatenate the encoded patches with the input embeddings
182
+ inputs_embeds = torch.cat(
183
+ (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
184
+ )
185
+ return self.base(
186
+ inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
187
+ )
188
+
189
+ def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
190
+ """
191
+ The generate function for generating a patch based on the encoded patch and already generated tokens.
192
+ :param encoded_patch: the encoded patch
193
+ :param tokens: already generated tokens in the patch
194
+ :return: the probability distribution of next token
195
+ """
196
+ encoded_patch = encoded_patch.reshape(1, 1, -1)
197
+ tokens = tokens.reshape(1, -1)
198
+ # Get input embeddings
199
+ tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
200
+ # Concatenate the encoded patch with the input embeddings
201
+ tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
202
+ # Get output from model
203
+ outputs = self.base(inputs_embeds=tokens)
204
+ # Get probabilities of next token
205
+ return torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
206
+
207
+
208
+ class TunesFormer(PreTrainedModel):
209
+ """
210
+ TunesFormer is a hierarchical music generation model based on bar patching.
211
+ It includes a patch-level decoder and a character-level decoder.
212
+ It inherits PreTrainedModel from transformers.
213
+ """
214
+
215
+ def __init__(self, encoder_config, decoder_config, share_weights=False):
216
+ super().__init__(encoder_config)
217
+ self.pad_token_id = 0
218
+ self.bos_token_id = 1
219
+ self.eos_token_id = 2
220
+ if share_weights:
221
+ max_layers = max(
222
+ encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
223
+ )
224
+ max_context_size = max(encoder_config.max_length, decoder_config.max_length)
225
+ max_position_embeddings = max(
226
+ encoder_config.max_position_embeddings,
227
+ decoder_config.max_position_embeddings,
228
+ )
229
+ encoder_config.num_hidden_layers = max_layers
230
+ encoder_config.max_length = max_context_size
231
+ encoder_config.max_position_embeddings = max_position_embeddings
232
+ decoder_config.num_hidden_layers = max_layers
233
+ decoder_config.max_length = max_context_size
234
+ decoder_config.max_position_embeddings = max_position_embeddings
235
+
236
+ self.patch_level_decoder = PatchLevelDecoder(encoder_config)
237
+ self.char_level_decoder = CharLevelDecoder(decoder_config)
238
+ if share_weights:
239
+ self.patch_level_decoder.base = self.char_level_decoder.base.transformer
240
+
241
+ def forward(
242
+ self,
243
+ patches: torch.Tensor,
244
+ patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
245
+ ):
246
+ """
247
+ The forward pass of the TunesFormer model.
248
+ :param patches: the patches to be both encoded and decoded
249
+ :return: the decoded patches
250
+ """
251
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
252
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
253
+ return self.char_level_decoder(
254
+ encoded_patches.squeeze(0)[:-1, :],
255
+ patches.squeeze(0)[1:, :],
256
+ patch_sampling_batch_size,
257
+ )
258
+
259
+ def generate(
260
+ self,
261
+ patches: torch.Tensor,
262
+ tokens: torch.Tensor,
263
+ top_p: float = 1,
264
+ top_k: int = 0,
265
+ temperature: float = 1,
266
+ seed: int = None,
267
+ ):
268
+ """
269
+ The generate function for generating patches based on patches.
270
+ :param patches: the patches to be encoded
271
+ :return: the generated patches
272
+ """
273
+ patches = patches.reshape(len(patches), -1, PATCH_SIZE)
274
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
275
+ if tokens == None:
276
+ tokens = torch.tensor([self.bos_token_id], device=self.device)
277
+
278
+ generated_patch = []
279
+ random.seed(seed)
280
+ while True:
281
+ if seed != None:
282
+ n_seed = random.randint(0, 1000000)
283
+ random.seed(n_seed)
284
+
285
+ else:
286
+ n_seed = None
287
+
288
+ prob = (
289
+ self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
290
+ .cpu()
291
+ .detach()
292
+ .numpy()
293
+ )
294
+ prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
295
+ prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
296
+ token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
297
+ generated_patch.append(token)
298
+ if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
299
+ break
300
+
301
+ else:
302
+ tokens = torch.cat(
303
+ (tokens, torch.tensor([token], device=self.device)), dim=0
304
+ )
305
+
306
+ return generated_patch, n_seed
307
+
308
+
309
+ class PatchilizedData(Dataset):
310
+ def __init__(self, items, patchilizer):
311
+ self.texts = []
312
+ for item in tqdm(items):
313
+ text = item["control code"] + "\n".join(
314
+ item["abc notation"].split("\n")[1:]
315
+ )
316
+ input_patch = patchilizer.encode(text, add_special_patches=True)
317
+ input_patch = torch.tensor(input_patch)
318
+ if torch.sum(input_patch) != 0:
319
+ self.texts.append(input_patch)
320
+
321
+ def __len__(self):
322
+ return len(self.texts)
323
+
324
+ def __getitem__(self, idx):
325
+ return self.texts[idx]
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.18.0
2
+ samplings==0.1.7
3
+ unidecode
4
+ music21
5
+ autopep8
6
+ pillow==9.4.0
7
+ pymupdf
8
+ torch
9
+ modelscope==1.15
utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import torch
5
+ import requests
6
+ import subprocess
7
+ from tqdm import tqdm
8
+ from modelscope import snapshot_download
9
+
10
+ TEYVAT = {
11
+ "蒙德 Mondstadt": "Mondstadt",
12
+ "璃月 Liyue": "Liyue",
13
+ "稻妻 Inazuma": "Inazuma",
14
+ "须弥 Sumeru": "Sumeru",
15
+ "枫丹 Fontaine": "Fontaine",
16
+ }
17
+ WEIGHTS_PATH = (
18
+ snapshot_download("MuGeminorum/hoyoMusic", cache_dir="./__pycache__")
19
+ + "/weights.pth"
20
+ )
21
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ TEMP_DIR = "./flagged"
23
+ PATCH_LENGTH = 128 # Patch Length
24
+ PATCH_SIZE = 32 # Patch Size
25
+ PATCH_NUM_LAYERS = 9 # Number of layers in the encoder
26
+ CHAR_NUM_LAYERS = 3 # Number of layers in the decoder
27
+ # Batch size for patch during training, 0 for full context
28
+ PATCH_SAMPLING_BATCH_SIZE = 0
29
+ # Whether to share weights between the encoder and decoder
30
+ SHARE_WEIGHTS = False
31
+
32
+
33
+ def download(filename: str, url: str):
34
+ try:
35
+ response = requests.get(url, stream=True)
36
+ total_size = int(response.headers.get("content-length", 0))
37
+ chunk_size = 1024
38
+ with open(filename, "wb") as file, tqdm(
39
+ desc=f"Downloading {filename} from '{url}'...",
40
+ total=total_size,
41
+ unit="B",
42
+ unit_scale=True,
43
+ unit_divisor=1024,
44
+ ) as bar:
45
+ for data in response.iter_content(chunk_size=chunk_size):
46
+ size = file.write(data)
47
+ bar.update(size)
48
+
49
+ except Exception as e:
50
+ print(f"Error: {e}, retrying...")
51
+ time.sleep(10)
52
+ download(filename, url)
53
+
54
+
55
+ if sys.platform.startswith("linux"):
56
+ apkname = "MuseScore.AppImage"
57
+ extra_dir = "squashfs-root"
58
+ download(
59
+ filename=apkname,
60
+ url="https://master.dl.sourceforge.net/project/musescore.mirror/v4.2.0/MuseScore-4.2.0.233521125-x86_64.AppImage?viasf=1",
61
+ )
62
+ if not os.path.exists(extra_dir):
63
+ subprocess.run(["chmod", "+x", f"./{apkname}"])
64
+ subprocess.run([f"./{apkname}", "--appimage-extract"])
65
+
66
+ MSCORE = f"./{extra_dir}/AppRun"
67
+ os.environ["QT_QPA_PLATFORM"] = "offscreen"
68
+
69
+ else:
70
+ MSCORE = "D:/Program Files/MuseScore 3/bin/MuseScore3.exe"