Monan Zhou commited on
Commit
3829121
1 Parent(s): 6c48757

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -267
app.py CHANGED
@@ -1,267 +1,259 @@
1
- import re
2
- import os
3
- import time
4
- import torch
5
- import shutil
6
- import argparse
7
- import gradio as gr
8
- from config import *
9
- from utils import Patchilizer, TunesFormer, DEVICE
10
- from convert import abc_to_midi, abc_to_musicxml, musicxml_to_mxl, mxl2jpg, midi2wav
11
- from modelscope import snapshot_download
12
- from transformers import GPT2Config
13
- import warnings
14
-
15
- warnings.filterwarnings("ignore")
16
-
17
- # 模型下载
18
- MODEL_DIR = snapshot_download("MuGeminorum/hoyoGPT")
19
-
20
-
21
- def get_args(parser: argparse.ArgumentParser):
22
- parser.add_argument(
23
- "-num_tunes",
24
- type=int,
25
- default=1,
26
- help="the number of independently computed returned tunes",
27
- )
28
- parser.add_argument(
29
- "-max_patch",
30
- type=int,
31
- default=128,
32
- help="integer to define the maximum length in tokens of each tune",
33
- )
34
- parser.add_argument(
35
- "-top_p",
36
- type=float,
37
- default=0.8,
38
- help="float to define the tokens that are within the sample operation of text generation",
39
- )
40
- parser.add_argument(
41
- "-top_k",
42
- type=int,
43
- default=8,
44
- help="integer to define the tokens that are within the sample operation of text generation",
45
- )
46
- parser.add_argument(
47
- "-temperature",
48
- type=float,
49
- default=1.2,
50
- help="the temperature of the sampling operation",
51
- )
52
- parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
53
- parser.add_argument(
54
- "-show_control_code",
55
- type=bool,
56
- default=False,
57
- help="whether to show control code",
58
- )
59
- args = parser.parse_args()
60
-
61
- return args
62
-
63
-
64
- def generate_abc(args, epochs: str, region: str):
65
- patchilizer = Patchilizer()
66
-
67
- patch_config = GPT2Config(
68
- num_hidden_layers=PATCH_NUM_LAYERS,
69
- max_length=PATCH_LENGTH,
70
- max_position_embeddings=PATCH_LENGTH,
71
- vocab_size=1,
72
- )
73
-
74
- char_config = GPT2Config(
75
- num_hidden_layers=CHAR_NUM_LAYERS,
76
- max_length=PATCH_SIZE,
77
- max_position_embeddings=PATCH_SIZE,
78
- vocab_size=128,
79
- )
80
-
81
- model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
82
- filename = f"{MODEL_DIR}/{epochs}/weights.pth"
83
- checkpoint = torch.load(filename, map_location=torch.device("cpu"))
84
- model.load_state_dict(checkpoint["model"])
85
- model = model.to(DEVICE)
86
- model.eval()
87
-
88
- prompt = f"A:{region}\n"
89
-
90
- tunes = ""
91
- num_tunes = args.num_tunes
92
- max_patch = args.max_patch
93
- top_p = args.top_p
94
- top_k = args.top_k
95
- temperature = args.temperature
96
- seed = args.seed
97
- show_control_code = args.show_control_code
98
-
99
- print(" HYPERPARAMETERS ".center(60, "#"), "\n")
100
- arg_dict: dict = vars(args)
101
-
102
- for key in arg_dict.keys():
103
- print(f"{key}: {str(arg_dict[key])}")
104
-
105
- print("\n", " OUTPUT TUNES ".center(60, "#"))
106
-
107
- start_time = time.time()
108
- for i in range(num_tunes):
109
- title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
110
- tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
111
- lines = re.split(r"(\n)", tune)
112
- tune = ""
113
- skip = False
114
- for line in lines:
115
- if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
116
- if not skip:
117
- print(line, end="")
118
- tune += line
119
-
120
- skip = False
121
-
122
- else:
123
- skip = True
124
-
125
- input_patches = torch.tensor(
126
- [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
127
- )
128
-
129
- if tune == "":
130
- tokens = None
131
-
132
- else:
133
- prefix = patchilizer.decode(input_patches[0])
134
- remaining_tokens = prompt[len(prefix) :]
135
- tokens = torch.tensor(
136
- [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
137
- device=DEVICE,
138
- )
139
-
140
- while input_patches.shape[1] < max_patch:
141
- predicted_patch, seed = model.generate(
142
- input_patches,
143
- tokens,
144
- top_p=top_p,
145
- top_k=top_k,
146
- temperature=temperature,
147
- seed=seed,
148
- )
149
- tokens = None
150
-
151
- if predicted_patch[0] != patchilizer.eos_token_id:
152
- next_bar = patchilizer.decode([predicted_patch])
153
-
154
- if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
155
- print(next_bar, end="")
156
- tune += next_bar
157
-
158
- if next_bar == "":
159
- break
160
-
161
- next_bar = remaining_tokens + next_bar
162
- remaining_tokens = ""
163
-
164
- predicted_patch = torch.tensor(
165
- patchilizer.bar2patch(next_bar), device=DEVICE
166
- ).unsqueeze(0)
167
-
168
- input_patches = torch.cat(
169
- [input_patches, predicted_patch.unsqueeze(0)], dim=1
170
- )
171
-
172
- else:
173
- break
174
-
175
- tunes += f"{tune}\n\n"
176
- print("\n")
177
-
178
- print("Generation time: {:.2f} seconds".format(time.time() - start_time))
179
- os.makedirs(TEMP_DIR, exist_ok=True)
180
- timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
181
- try:
182
- out_midi = abc_to_midi(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.mid")
183
- out_xml = abc_to_musicxml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
184
- out_mxl = musicxml_to_mxl(f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
185
- pdf_file, jpg_file = mxl2jpg(out_mxl)
186
- wav_file = midi2wav(out_midi)
187
-
188
- return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file
189
-
190
- except Exception as e:
191
- print(f"Invalid abc generated: {e}, retrying...")
192
- return generate_abc(args, epochs, region)
193
-
194
-
195
- def inference(epochs: str, region: str):
196
- Teyvat = {
197
- "蒙德": "Mondstadt",
198
- "璃月": "Liyue",
199
- "稻妻": "Inazuma",
200
- "须弥": "Sumeru",
201
- "枫丹": "Fontaine",
202
- }
203
-
204
- if os.path.exists(TEMP_DIR):
205
- shutil.rmtree(TEMP_DIR)
206
-
207
- parser = argparse.ArgumentParser()
208
- args = get_args(parser)
209
- return generate_abc(args, epochs, Teyvat[region])
210
-
211
-
212
- if __name__ == "__main__":
213
- with gr.Blocks() as demo:
214
- gr.Markdown(
215
- """<center>欢迎使用此创空间,此创空间由bilibili <a href="https://space.bilibili.com/30620472">@亦真亦幻Studio</a> 基于 Tunesformer 开源项目制作,完全免费。</center>"""
216
- )
217
- with gr.Row():
218
- with gr.Column():
219
- weight_opt = gr.Dropdown(
220
- choices=["5", "15"],
221
- value="15",
222
- label="模型选择(epochs)",
223
- )
224
- region_opt = gr.Dropdown(
225
- choices=["蒙德", "璃月", "稻妻", "须弥", "枫丹"],
226
- value="蒙德",
227
- label="地区风格",
228
- )
229
- gen_btn = gr.Button("生成")
230
- gr.Markdown(
231
- """
232
- <center>
233
- 当前模型还在调试中,由于训练数据由 MIDI 转换而来,存在大量谱面不规范问题,导致很多生成结果有谱面不规范等问题。<br>
234
-
235
- 计划在原神主线杀青后,所有国家地区角色全部开放后,二创音乐会齐全且样本均衡,届时重新微调模型并添加现实风格筛选辅助游戏各国家输出把关,以提升输出区分度与质量。
236
-
237
- 数据来源:<a href="https://musescore.org">MuseScore</a><br>
238
- Tag 嵌入数据来源:<a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
239
- 模型基础:<a href="https://github.com/sander-wood/tunesformer">Tunesformer</a>
240
-
241
- 注:崩铁方面数据工程正在运作中,未来也希望随主线杀青而基线化。</center>"""
242
- )
243
-
244
- with gr.Column():
245
- wav_output = gr.Audio(label="音频", type="filepath")
246
- dld_midi = gr.components.File(label="下载 MIDI")
247
- pdf_score = gr.components.File(label="下载 PDF 乐谱")
248
- dld_xml = gr.components.File(label="下载 MusicXML")
249
- dld_mxl = gr.components.File(label="下载 MXL")
250
- abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
251
- img_score = gr.Image(label="五线谱", type="filepath")
252
-
253
- gen_btn.click(
254
- inference,
255
- inputs=[weight_opt, region_opt],
256
- outputs=[
257
- abc_output,
258
- dld_midi,
259
- pdf_score,
260
- dld_xml,
261
- dld_mxl,
262
- img_score,
263
- wav_output,
264
- ],
265
- )
266
-
267
- demo.launch()
 
1
+ import re
2
+ import os
3
+ import time
4
+ import torch
5
+ import shutil
6
+ import argparse
7
+ import gradio as gr
8
+ from config import *
9
+ from utils import Patchilizer, TunesFormer, DEVICE
10
+ from convert import abc_to_midi, abc_to_musicxml, musicxml_to_mxl, mxl2jpg, midi2wav
11
+ from modelscope import snapshot_download
12
+ from transformers import GPT2Config
13
+ import warnings
14
+
15
+ warnings.filterwarnings("ignore")
16
+
17
+ # 模型下载
18
+ MODEL_DIR = snapshot_download("MuGeminorum/hoyoGPT")
19
+
20
+
21
+ def get_args(parser: argparse.ArgumentParser):
22
+ parser.add_argument(
23
+ "-num_tunes",
24
+ type=int,
25
+ default=1,
26
+ help="the number of independently computed returned tunes",
27
+ )
28
+ parser.add_argument(
29
+ "-max_patch",
30
+ type=int,
31
+ default=128,
32
+ help="integer to define the maximum length in tokens of each tune",
33
+ )
34
+ parser.add_argument(
35
+ "-top_p",
36
+ type=float,
37
+ default=0.8,
38
+ help="float to define the tokens that are within the sample operation of text generation",
39
+ )
40
+ parser.add_argument(
41
+ "-top_k",
42
+ type=int,
43
+ default=8,
44
+ help="integer to define the tokens that are within the sample operation of text generation",
45
+ )
46
+ parser.add_argument(
47
+ "-temperature",
48
+ type=float,
49
+ default=1.2,
50
+ help="the temperature of the sampling operation",
51
+ )
52
+ parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
53
+ parser.add_argument(
54
+ "-show_control_code",
55
+ type=bool,
56
+ default=False,
57
+ help="whether to show control code",
58
+ )
59
+ args = parser.parse_args()
60
+
61
+ return args
62
+
63
+
64
+ def generate_abc(args, epochs: str, region: str):
65
+ patchilizer = Patchilizer()
66
+
67
+ patch_config = GPT2Config(
68
+ num_hidden_layers=PATCH_NUM_LAYERS,
69
+ max_length=PATCH_LENGTH,
70
+ max_position_embeddings=PATCH_LENGTH,
71
+ vocab_size=1,
72
+ )
73
+
74
+ char_config = GPT2Config(
75
+ num_hidden_layers=CHAR_NUM_LAYERS,
76
+ max_length=PATCH_SIZE,
77
+ max_position_embeddings=PATCH_SIZE,
78
+ vocab_size=128,
79
+ )
80
+
81
+ model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
82
+ filename = f"{MODEL_DIR}/{epochs}/weights.pth"
83
+ checkpoint = torch.load(filename, map_location=torch.device("cpu"))
84
+ model.load_state_dict(checkpoint["model"])
85
+ model = model.to(DEVICE)
86
+ model.eval()
87
+
88
+ prompt = f"A:{region}\n"
89
+
90
+ tunes = ""
91
+ num_tunes = args.num_tunes
92
+ max_patch = args.max_patch
93
+ top_p = args.top_p
94
+ top_k = args.top_k
95
+ temperature = args.temperature
96
+ seed = args.seed
97
+ show_control_code = args.show_control_code
98
+
99
+ print(" HYPERPARAMETERS ".center(60, "#"), "\n")
100
+ arg_dict: dict = vars(args)
101
+
102
+ for key in arg_dict.keys():
103
+ print(f"{key}: {str(arg_dict[key])}")
104
+
105
+ print("\n", " OUTPUT TUNES ".center(60, "#"))
106
+
107
+ start_time = time.time()
108
+ for i in range(num_tunes):
109
+ title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
110
+ tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
111
+ lines = re.split(r"(\n)", tune)
112
+ tune = ""
113
+ skip = False
114
+ for line in lines:
115
+ if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
116
+ if not skip:
117
+ print(line, end="")
118
+ tune += line
119
+
120
+ skip = False
121
+
122
+ else:
123
+ skip = True
124
+
125
+ input_patches = torch.tensor(
126
+ [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
127
+ )
128
+
129
+ if tune == "":
130
+ tokens = None
131
+
132
+ else:
133
+ prefix = patchilizer.decode(input_patches[0])
134
+ remaining_tokens = prompt[len(prefix) :]
135
+ tokens = torch.tensor(
136
+ [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
137
+ device=DEVICE,
138
+ )
139
+
140
+ while input_patches.shape[1] < max_patch:
141
+ predicted_patch, seed = model.generate(
142
+ input_patches,
143
+ tokens,
144
+ top_p=top_p,
145
+ top_k=top_k,
146
+ temperature=temperature,
147
+ seed=seed,
148
+ )
149
+ tokens = None
150
+
151
+ if predicted_patch[0] != patchilizer.eos_token_id:
152
+ next_bar = patchilizer.decode([predicted_patch])
153
+
154
+ if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
155
+ print(next_bar, end="")
156
+ tune += next_bar
157
+
158
+ if next_bar == "":
159
+ break
160
+
161
+ next_bar = remaining_tokens + next_bar
162
+ remaining_tokens = ""
163
+
164
+ predicted_patch = torch.tensor(
165
+ patchilizer.bar2patch(next_bar), device=DEVICE
166
+ ).unsqueeze(0)
167
+
168
+ input_patches = torch.cat(
169
+ [input_patches, predicted_patch.unsqueeze(0)], dim=1
170
+ )
171
+
172
+ else:
173
+ break
174
+
175
+ tunes += f"{tune}\n\n"
176
+ print("\n")
177
+
178
+ print("Generation time: {:.2f} seconds".format(time.time() - start_time))
179
+ os.makedirs(TEMP_DIR, exist_ok=True)
180
+ timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
181
+ try:
182
+ out_midi = abc_to_midi(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.mid")
183
+ out_xml = abc_to_musicxml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
184
+ out_mxl = musicxml_to_mxl(f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
185
+ pdf_file, jpg_file = mxl2jpg(out_mxl)
186
+ wav_file = midi2wav(out_midi)
187
+
188
+ return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file
189
+
190
+ except Exception as e:
191
+ print(f"Invalid abc generated: {e}, retrying...")
192
+ return generate_abc(args, epochs, region)
193
+
194
+
195
+ def inference(epochs: str, region: str):
196
+ if os.path.exists(TEMP_DIR):
197
+ shutil.rmtree(TEMP_DIR)
198
+
199
+ parser = argparse.ArgumentParser()
200
+ args = get_args(parser)
201
+ return generate_abc(args, epochs, region)
202
+
203
+
204
+ if __name__ == "__main__":
205
+ with gr.Blocks() as demo:
206
+ gr.Markdown(
207
+ """<center>欢迎使用此创空间,此创空间由bilibili <a href="https://space.bilibili.com/30620472">@亦真亦幻Studio</a> 基于 Tunesformer 开源项目制作,完全免费。</center>"""
208
+ )
209
+ with gr.Row():
210
+ with gr.Column():
211
+ weight_opt = gr.Dropdown(
212
+ choices=["5", "15"],
213
+ value="15",
214
+ label="Model Selection(epochs)",
215
+ )
216
+ region_opt = gr.Dropdown(
217
+ choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
218
+ value="Mondstadt",
219
+ label="Region",
220
+ )
221
+ gen_btn = gr.Button("生成")
222
+ gr.Markdown(
223
+ """
224
+ <center>
225
+ Currently, the model is still under debugging, and since the training data is converted from MIDI, there are a lot of spectral irregularities, which leads to many generated results with spectral irregularities and other problems.<br>
226
+
227
+ Planned in the Genshin main line killed, all countries and regions after all the characters are open, the second creation of the concert will be complete and balanced samples, then re-fine-tune the model and add the reality of the style of screening to assist the game of the various countries output gatekeepers, in order to enhance the output of the differentiation and quality.
228
+
229
+ Data source: <a href="https://musescore.org">MuseScore</a><br>
230
+ Tag embedded data source: <a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a><br>
231
+ Base model: <a href="https://github.com/sander-wood/tunesformer">Tunesformer</a>
232
+
233
+ Note: Data engineering on the Honkai: Star Rail side is in operation, and will hopefully be baselined in the future as well with the mainline kill.</center>"""
234
+ )
235
+
236
+ with gr.Column():
237
+ wav_output = gr.Audio(label="Audio", type="filepath")
238
+ dld_midi = gr.components.File(label="Download MIDI")
239
+ pdf_score = gr.components.File(label="Download PDF Score")
240
+ dld_xml = gr.components.File(label="Download MusicXML")
241
+ dld_mxl = gr.components.File(label="Download MXL")
242
+ abc_output = gr.Textbox(label="abc notation", show_copy_button=True)
243
+ img_score = gr.Image(label="Staff", type="filepath")
244
+
245
+ gen_btn.click(
246
+ inference,
247
+ inputs=[weight_opt, region_opt],
248
+ outputs=[
249
+ abc_output,
250
+ dld_midi,
251
+ pdf_score,
252
+ dld_xml,
253
+ dld_mxl,
254
+ img_score,
255
+ wav_output,
256
+ ],
257
+ )
258
+
259
+ demo.launch()