MuGeminorum committed on
Commit 653dc95
1 Parent(s): 37e9aba
Files changed (9)
  1. .gitattributes +12 -11
  2. .gitignore +8 -0
  3. README.md +4 -4
  4. app.py +215 -0
  5. conda.txt +5 -0
  6. config.py +19 -0
  7. render.py +73 -0
  8. requirements.txt +9 -0
  9. utils.py +388 -0
.gitattributes CHANGED
@@ -1,35 +1,36 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.sf3 filter=lfs diff=lfs merge=lfs -text
+ *.AppImage filter=lfs diff=lfs merge=lfs -textlibnss3.so filter=lfs diff=lfs merge=lfs -text
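
Note: these rules keep large binaries in Git LFS; the added patterns cover TensorFlow-style checkpoint shards, *.safetensors and *.ckpt weights, *.sf3 soundfonts, and the MuseScore *.AppImage used by render.py below. The last added line appears to be missing a newline: it fuses the *.AppImage rule with an intended `libnss3.so filter=lfs diff=lfs merge=lfs -text` rule, so as committed the libnss3.so pattern is never actually applied.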
.gitignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__/*
+ output/*
+ rename.sh
+ test.py
+ gpt2-abcmusic/*
+ *.pth
+ tmp/*
+ mscore3/*
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: HoyoGPT
- emoji: 🐢
- colorFrom: pink
- colorTo: purple
+ emoji: 🎹
+ colorFrom: green
+ colorTo: pink
  sdk: gradio
- sdk_version: 4.12.0
+ sdk_version: 4.7.1
  app_file: app.py
  pinned: false
  license: mit
app.py ADDED
@@ -0,0 +1,215 @@
+ import re
+ import os
+ import time
+ import torch
+ import shutil
+ import argparse
+ import gradio as gr
+ from utils import *
+ from config import *
+ from render import *
+ from music21 import converter
+ from transformers import GPT2Config
+ import warnings
+ warnings.filterwarnings('ignore')
+
+
+ def abc_to_midi(abc_content, output_midi_path):
+     # Parse the score in ABC notation
+     score = converter.parse(abc_content)
+     # Save the score as a MIDI file
+     score.write('midi', fp=output_midi_path)
+     return output_midi_path
+
+
+ def get_args(parser):
+     parser.add_argument('-num_tunes', type=int, default=1,
+                         help='the number of independently computed returned tunes')
+     parser.add_argument('-max_patch', type=int, default=128,
+                         help='integer to define the maximum length in tokens of each tune')
+     parser.add_argument('-top_p', type=float, default=0.8,
+                         help='float to define the tokens that are within the sample operation of text generation')
+     parser.add_argument('-top_k', type=int, default=8,
+                         help='integer to define the tokens that are within the sample operation of text generation')
+     parser.add_argument('-temperature', type=float, default=1.2,
+                         help='the temperature of the sampling operation')
+     parser.add_argument('-seed', type=int, default=None,
+                         help='seed for random state')
+     parser.add_argument('-show_control_code', type=bool,
+                         default=True, help='whether to show control code')
+     args = parser.parse_args()
+
+     return args
+
+
+ def generate_abc(args, region):
+     patchilizer = Patchilizer()
+
+     patch_config = GPT2Config(
+         num_hidden_layers=PATCH_NUM_LAYERS,
+         max_length=PATCH_LENGTH,
+         max_position_embeddings=PATCH_LENGTH,
+         vocab_size=1
+     )
+
+     char_config = GPT2Config(
+         num_hidden_layers=CHAR_NUM_LAYERS,
+         max_length=PATCH_SIZE,
+         max_position_embeddings=PATCH_SIZE,
+         vocab_size=128
+     )
+
+     model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
+
+     filename = WEIGHT_PATH
+
+     if os.path.exists(filename):
+         print(f"Weights already exist at '{filename}'. Loading...")
+     else:
+         download()
+
+     checkpoint = torch.load(filename, map_location=torch.device('cpu'))
+     model.load_state_dict(checkpoint['model'])
+     model = model.to(device)
+     model.eval()
+
+     prompt = template(region)
+
+     tunes = ""
+     num_tunes = args.num_tunes
+     max_patch = args.max_patch
+     top_p = args.top_p
+     top_k = args.top_k
+     temperature = args.temperature
+     seed = args.seed
+     show_control_code = args.show_control_code
+
+     print(" HYPERPARAMETERS ".center(60, "#"), '\n')
+     args = vars(args)
+
+     for key in args.keys():
+         print(f'{key}: {str(args[key])}')
+
+     print('\n', " OUTPUT TUNES ".center(60, "#"))
+
+     start_time = time.time()
+
+     for i in range(num_tunes):
+         tune = f"X:{str(i + 1)}\n{prompt}"
+         lines = re.split(r'(\n)', tune)
+         tune = ""
+         skip = False
+         for line in lines:
+             if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
+                 if not skip:
+                     print(line, end="")
+                     tune += line
+
+                 skip = False
+             else:
+                 skip = True
+
+         input_patches = torch.tensor(
+             [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
+             device=device
+         )
+
+         if tune == "":
+             tokens = None
+         else:
+             prefix = patchilizer.decode(input_patches[0])
+             remaining_tokens = prompt[len(prefix):]
+             tokens = torch.tensor(
+                 [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
+                 device=device
+             )
+
+         while input_patches.shape[1] < max_patch:
+             predicted_patch, seed = model.generate(
+                 input_patches,
+                 tokens,
+                 top_p=top_p,
+                 top_k=top_k,
+                 temperature=temperature,
+                 seed=seed
+             )
+             tokens = None
+
+             if predicted_patch[0] != patchilizer.eos_token_id:
+                 next_bar = patchilizer.decode([predicted_patch])
+
+                 if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
+                     print(next_bar, end="")
+                     tune += next_bar
+
+                 if next_bar == "":
+                     break
+
+                 next_bar = remaining_tokens + next_bar
+                 remaining_tokens = ""
+
+                 predicted_patch = torch.tensor(
+                     patchilizer.bar2patch(next_bar),
+                     device=device
+                 ).unsqueeze(0)
+
+                 input_patches = torch.cat(
+                     [input_patches, predicted_patch.unsqueeze(0)],
+                     dim=1
+                 )
+             else:
+                 break
+
+         tunes += f"{tune}\n\n"
+         print("\n")
+
+     print("Generation time: {:.2f} seconds".format(time.time() - start_time))
+     create_dir('./tmp')
+     timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
+     out_midi = abc_to_midi(tunes, f'./tmp/[{region}]{timestamp}.mid')
+     add_path()
+     png_file = midi2png(out_midi)
+     wav_file = midi2wav(out_midi)
+
+     return tunes, out_midi, png_file, wav_file
+
+
+ def inference(region):
+     if os.path.exists('./tmp'):
+         shutil.rmtree('./tmp')
+
+     parser = argparse.ArgumentParser()
+     args = get_args(parser)
+     return generate_abc(args, region)
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             region_opt = gr.Dropdown(
+                 choices=[
+                     'Mondstadt', 'Liyue', 'Inazuma', 'Sumeru', 'Fontaine'
+                 ],
+                 value='Liyue',
+                 label='Region'
+             )
+             gen_btn = gr.Button("Generate")
+
+         with gr.Column():
+             wav_output = gr.Audio(label='Audio', type='filepath')
+             dld_midi = gr.components.File(label="Download MIDI")
+             abc_output = gr.TextArea(label='abc score')
+             img_score = gr.Image(label='Staff', type='filepath')
+
+     gen_btn.click(
+         inference,
+         inputs=region_opt,
+         outputs=[abc_output, dld_midi, img_score, wav_output]
+     )
+
+ demo.launch(share=True)
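
For reference, a minimal sketch of the ABC-to-MIDI step above, assuming music21 is installed; the tune string is a hypothetical hand-written example rather than model output:

    from music21 import converter

    # a tiny ABC tune in the same key and meter as the template in utils.py
    abc = '''X:1
    L:1/8
    M:3/4
    K:D
    de |"D" f2 a2 d2 |]'''

    converter.parse(abc).write('midi', fp='demo.mid')  # same call path as abc_to_midi()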
conda.txt ADDED
@@ -0,0 +1,5 @@
+ python=3.10
+ pytorch=1.12.1
+ torchvision=0.13.1
+ torchaudio=0.12.1
+ cudatoolkit=11.3.1
config.py ADDED
@@ -0,0 +1,19 @@
+ PATCH_LENGTH = 128  # Patch length
+ PATCH_SIZE = 32  # Patch size
+
+ PATCH_NUM_LAYERS = 9  # Number of layers in the encoder
+ CHAR_NUM_LAYERS = 3  # Number of layers in the decoder
+
+ # Number of epochs to train for (if early stopping doesn't intervene)
+ NUM_EPOCHS = 5  # 32
+ LEARNING_RATE = 5e-5  # Learning rate for the optimizer
+ # Batch size for patch sampling during training, 0 for full context
+ PATCH_SAMPLING_BATCH_SIZE = 0
+ LOAD_FROM_CHECKPOINT = True  # Whether to load weights from a checkpoint
+ # Whether to share weights between the encoder and decoder
+ SHARE_WEIGHTS = False
+ WEIGHT_URL = 'https://huggingface.co/MuGeminorum/hoyoGPT/resolve/main/weights.pth'
+ ZH_WEIGHT_URL = 'https://www.modelscope.cn/api/v1/models/MuGeminorum/hoyoGPT/repo?Revision=master&FilePath=weights.pth'
+ WEIGHT_PATH = 'weights.pth'
+ LOG_PATH = 'logs.txt'
+ PROMPT_PATH = 'prompt.txt'
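
Note: taken together, PATCH_LENGTH = 128 and PATCH_SIZE = 32 cap each encoded tune at 128 bar patches of 32 character tokens, i.e. at most 128 × 32 = 4096 characters of ABC text per sample (including the BOS/EOS patches that Patchilizer.encode in utils.py adds).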
render.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import sys
+ import subprocess
+ from PIL import Image
+ from pdf2image import convert_from_path
+ from utils import download
+
+
+ def add_path():
+     # """
+     # Add the given directory to the LD_LIBRARY_PATH environment variable,
+     # taking effect in the current Python process.
+
+     # Parameters:
+     # - directory_path (str): the directory path to add.
+     # """
+     # dir_path = os.path.join(os.getcwd(), 'lib')
+     # # Get the current value of the environment variable
+     # current_path = os.environ.get("LD_LIBRARY_PATH", "")
+
+     # # Append the directory path to LD_LIBRARY_PATH
+     # new_path = f"{current_path}:{dir_path}"
+
+     # # Set LD_LIBRARY_PATH so the change takes effect in the current Python process
+     # os.environ["LD_LIBRARY_PATH"] = new_path
+     os.environ['QT_QPA_PLATFORM'] = 'offscreen'
+
+
+ if sys.platform.startswith('linux'):
+     apkname = 'MuseScore.AppImage'
+     extra_dir = 'squashfs-root'
+     download(
+         filename=apkname,
+         url='https://cdn.jsdelivr.net/musescore/v4.2.0/MuseScore-4.2.0.233521125-x86_64.AppImage'
+     )
+     if not os.path.exists(extra_dir):
+         subprocess.run(['chmod', '+x', f'./{apkname}'])
+         subprocess.run([f'./{apkname}', '--appimage-extract'])
+
+     mscore = f'./{extra_dir}/AppRun'
+
+ else:
+     mscore = "D:/Program Files/MuseScore 3/bin/MuseScore3.exe"
+
+
+ def midi2wav(mid_file: str):
+     wav_file = mid_file.replace('.mid', '.wav')
+     command = [mscore, "-o", wav_file, mid_file]
+     result = subprocess.run(command)
+     print(result)
+     return wav_file
+
+
+ def pdf_to_img(pdf_path: str):
+     output_path = pdf_path.replace('.pdf', '.jpg')
+     images = convert_from_path(pdf_path)
+     combined_image = Image.new(
+         'RGB', (images[0].width, sum(image.height for image in images))
+     )
+     y_offset = 0
+     for image in images:
+         combined_image.paste(image, (0, y_offset))
+         y_offset += image.height
+
+     combined_image.save(output_path)
+     return output_path
+
+
+ def midi2png(mid_file: str):
+     pdf_score = mid_file.replace('.mid', '.pdf')
+     command = [mscore, "-o", pdf_score, mid_file]
+     result = subprocess.run(command)
+     print(result)
+     return pdf_to_img(pdf_score)
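
For reference, a minimal usage sketch of the helpers above, assuming MuseScore resolved for the current platform and that a MIDI file already exists at the hypothetical path ./tmp/demo.mid:

    from render import add_path, midi2wav, midi2png

    add_path()  # force Qt offscreen rendering so MuseScore works headless
    wav_file = midi2wav('./tmp/demo.mid')  # writes ./tmp/demo.wav via MuseScore
    jpg_file = midi2png('./tmp/demo.mid')  # exports a PDF score, then flattens its pages into ./tmp/demo.jpg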
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ transformers==4.18.0
+ samplings==0.1.7
+ unidecode
+ music21
+ autopep8
+ pillow==9.4.0
+ gradio
+ pdf2image
+ torch
utils.py ADDED
@@ -0,0 +1,388 @@
+ import os
+ import re
+ import torch
+ import random
+ from config import *
+ from tqdm import tqdm
+ from unidecode import unidecode
+ from torch.utils.data import Dataset
+ from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+
+
+ def template(region):
+     return f'''A:{region}
+ S:2
+ B:9
+ E:4
+ B:9
+ L:1/8
+ M:3/4
+ K:D
+ de |"D"'''
+
+
+ def create_dir(dir_path):
+     if not os.path.exists(dir_path):
+         os.makedirs(dir_path)
+
+
+ def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
+     import time
+     import requests
+     try:
+         response = requests.get(url, stream=True)
+         total_size = int(response.headers.get('content-length', 0))
+         chunk_size = 1024
+
+         with open(filename, 'wb') as file, tqdm(
+             desc=f"Downloading weights to '{filename}'...",
+             total=total_size,
+             unit='B',
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as bar:
+             for data in response.iter_content(chunk_size=chunk_size):
+                 size = file.write(data)
+                 bar.update(size)
+
+     except requests.exceptions.ConnectionError as e:
+         # fall back to the ModelScope mirror if the download fails
+         print(f"Error: {e}")
+         time.sleep(3)
+         download(filename, ZH_WEIGHT_URL)
+
+
+ class Patchilizer:
+     """
+     A class for converting music bars to patches and vice versa.
+     """
+
+     def __init__(self):
+         self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
+         self.regexPattern = f"({'|'.join(map(re.escape, self.delimiters))})"
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+
+     def split_bars(self, body):
+         """
+         Split a body of music into individual bars.
+         """
+         bars = re.split(self.regexPattern, ''.join(body))
+         bars = list(filter(None, bars))  # remove empty strings
+         if bars[0] in self.delimiters:
+             bars[1] = bars[0] + bars[1]
+             bars = bars[1:]
+
+         bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
+         return bars
+
+     def bar2patch(self, bar, patch_size=PATCH_SIZE):
+         """
+         Convert a bar into a patch of specified length.
+         """
+         patch = [self.bos_token_id] + \
+             [ord(c) for c in bar] + [self.eos_token_id]
+         patch = patch[:patch_size]
+         patch += [self.pad_token_id] * (patch_size - len(patch))
+         return patch
+
+     def patch2bar(self, patch):
+         """
+         Convert a patch into a bar.
+         """
+         return ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in patch if idx != self.eos_token_id)
+
+     def encode(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=False):
+         """
+         Encode music into patches of specified length.
+         """
+         lines = unidecode(abc_code).split('\n')
+         lines = list(filter(None, lines))  # remove empty lines
+
+         body = ""
+         patches = []
+
+         for line in lines:
+             if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
+                 if body:
+                     bars = self.split_bars(body)
+                     patches.extend(
+                         self.bar2patch(bar + '\n' if idx == len(bars) - 1 else bar, patch_size)
+                         for idx, bar in enumerate(bars)
+                     )
+                     body = ""
+
+                 patches.append(self.bar2patch(line + '\n', patch_size))
+             else:
+                 body += line + '\n'
+
+         if body:
+             patches.extend(
+                 self.bar2patch(bar, patch_size) for bar in self.split_bars(body)
+             )
+
+         if add_special_patches:
+             bos_patch = [self.bos_token_id] * \
+                 (patch_size - 1) + [self.eos_token_id]
+             eos_patch = [self.bos_token_id] + \
+                 [self.eos_token_id] * (patch_size - 1)
+             patches = [bos_patch] + patches + [eos_patch]
+
+         return patches[:patch_length]
+
+     def decode(self, patches):
+         """
+         Decode patches into music.
+         """
+         return ''.join(self.patch2bar(patch) for patch in patches)
+
+
+ class PatchLevelDecoder(PreTrainedModel):
+     """
+     A Patch-level Decoder model for generating patch features in an auto-regressive manner.
+     It inherits PreTrainedModel from transformers.
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
+         torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
+         self.base = GPT2Model(config)
+
+     def forward(self, patches: torch.Tensor) -> torch.Tensor:
+         """
+         The forward pass of the patch-level decoder model.
+         :param patches: the patches to be encoded
+         :return: the encoded patches
+         """
+         patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
+         patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
+         patches = self.patch_embedding(patches.to(self.device))
+
+         return self.base(inputs_embeds=patches)
+
+
+ class CharLevelDecoder(PreTrainedModel):
+     """
+     A Char-level Decoder model for generating the characters within each bar patch sequentially.
+     It inherits PreTrainedModel from transformers.
+     """
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.base = GPT2LMHeadModel(config)
+
+     def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor, patch_sampling_batch_size: int):
+         """
+         The forward pass of the char-level decoder model.
+         :param encoded_patches: the encoded patches
+         :param target_patches: the target patches
+         :return: the decoded patches
+         """
+         # preparing the labels for model training
+         target_masks = target_patches == self.pad_token_id
+         labels = target_patches.clone().masked_fill_(target_masks, -100)
+
+         # masking the labels for model training
+         target_masks = torch.ones_like(labels)
+         target_masks = target_masks.masked_fill_(labels == -100, 0)
+
+         # select patches
+         if patch_sampling_batch_size != 0 and patch_sampling_batch_size < target_patches.shape[0]:
+             indices = list(range(len(target_patches)))
+             random.shuffle(indices)
+             selected_indices = sorted(indices[:patch_sampling_batch_size])
+
+             target_patches = target_patches[selected_indices, :]
+             target_masks = target_masks[selected_indices, :]
+             encoded_patches = encoded_patches[selected_indices, :]
+             labels = labels[selected_indices, :]
+
+         # get input embeddings
+         inputs_embeds = torch.nn.functional.embedding(
+             target_patches,
+             self.base.transformer.wte.weight
+         )
+
+         # concatenate the encoded patches with the input embeddings
+         inputs_embeds = torch.cat(
+             (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]),
+             dim=1
+         )
+
+         return self.base(
+             inputs_embeds=inputs_embeds,
+             attention_mask=target_masks,
+             labels=labels
+         )
+
+     def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
+         """
+         The generate function for generating a patch based on the encoded patch and already generated tokens.
+         :param encoded_patch: the encoded patch
+         :param tokens: already generated tokens in the patch
+         :return: the probability distribution of the next token
+         """
+         encoded_patch = encoded_patch.reshape(1, 1, -1)
+         tokens = tokens.reshape(1, -1)
+
+         # Get input embeddings
+         tokens = torch.nn.functional.embedding(
+             tokens,
+             self.base.transformer.wte.weight
+         )
+
+         # Concatenate the encoded patch with the input embeddings
+         tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
+
+         # Get output from model
+         outputs = self.base(inputs_embeds=tokens)
+
+         # Get probabilities of next token
+         probs = torch.nn.functional.softmax(
+             outputs.logits.squeeze(0)[-1],
+             dim=-1
+         )
+
+         return probs
+
+
+ class TunesFormer(PreTrainedModel):
+     """
+     TunesFormer is a hierarchical music generation model based on bar patching.
+     It includes a patch-level decoder and a character-level decoder.
+     It inherits PreTrainedModel from transformers.
+     """
+
+     def __init__(self, encoder_config, decoder_config, share_weights=False):
+         super().__init__(encoder_config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         if share_weights:
+             max_layers = max(
+                 encoder_config.num_hidden_layers,
+                 decoder_config.num_hidden_layers
+             )
+
+             max_context_size = max(
+                 encoder_config.max_length,
+                 decoder_config.max_length
+             )
+
+             max_position_embeddings = max(
+                 encoder_config.max_position_embeddings,
+                 decoder_config.max_position_embeddings
+             )
+
+             encoder_config.num_hidden_layers = max_layers
+             encoder_config.max_length = max_context_size
+             encoder_config.max_position_embeddings = max_position_embeddings
+             decoder_config.num_hidden_layers = max_layers
+             decoder_config.max_length = max_context_size
+             decoder_config.max_position_embeddings = max_position_embeddings
+
+         self.patch_level_decoder = PatchLevelDecoder(encoder_config)
+         self.char_level_decoder = CharLevelDecoder(decoder_config)
+
+         if share_weights:
+             self.patch_level_decoder.base = self.char_level_decoder.base.transformer
+
+     def forward(self, patches: torch.Tensor, patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE):
+         """
+         The forward pass of the TunesFormer model.
+         :param patches: the patches to be both encoded and decoded
+         :return: the decoded patches
+         """
+         patches = patches.reshape(len(patches), -1, PATCH_SIZE)
+         encoded_patches = self.patch_level_decoder(
+             patches)["last_hidden_state"]
+
+         return self.char_level_decoder(encoded_patches.squeeze(0)[:-1, :], patches.squeeze(0)[1:, :], patch_sampling_batch_size)
+
+     def generate(
+         self,
+         patches: torch.Tensor,
+         tokens: torch.Tensor,
+         top_p: float = 1,
+         top_k: int = 0,
+         temperature: float = 1,
+         seed: int = None
+     ):
+         """
+         The generate function for generating patches based on patches.
+         :param patches: the patches to be encoded
+         :return: the generated patches
+         """
+         patches = patches.reshape(len(patches), -1, PATCH_SIZE)
+         encoded_patches = self.patch_level_decoder(
+             patches)["last_hidden_state"]
+
+         if tokens is None:
+             tokens = torch.tensor([self.bos_token_id], device=self.device)
+
+         generated_patch = []
+         random.seed(seed)
+
+         while True:
+             if seed is not None:
+                 n_seed = random.randint(0, 1000000)
+                 random.seed(n_seed)
+             else:
+                 n_seed = None
+
+             prob = self.char_level_decoder.generate(
+                 encoded_patches[0][-1],
+                 tokens
+             ).cpu().detach().numpy()
+
+             prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
+             prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
+
+             token = temperature_sampling(
+                 prob,
+                 temperature=temperature,
+                 seed=n_seed
+             )
+
+             generated_patch.append(token)
+             if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
+                 break
+             else:
+                 tokens = torch.cat(
+                     (tokens, torch.tensor([token], device=self.device)),
+                     dim=0
+                 )
+
+         return generated_patch, n_seed
+
+
+ class PatchilizedData(Dataset):
+     def __init__(self, items, patchilizer):
+         self.texts = []
+
+         for item in tqdm(items):
+             text = item['control code'] + \
+                 "\n".join(item['abc notation'].split('\n')[1:])
+             input_patch = patchilizer.encode(text, add_special_patches=True)
+             input_patch = torch.tensor(input_patch)
+             if torch.sum(input_patch) != 0:
+                 self.texts.append(input_patch)
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         return self.texts[idx]
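
For reference, a minimal round-trip sketch of the Patchilizer above; the ABC snippet is a hypothetical example:

    from utils import Patchilizer

    patchilizer = Patchilizer()
    abc = 'L:1/8\nM:3/4\nK:D\nde |"D" f2 a2 d2 |'
    patches = patchilizer.encode(abc, add_special_patches=True)
    print(len(patches))  # one 32-token patch per header line or bar, plus BOS/EOS patches
    print(patchilizer.decode(patches[1:-1]))  # recovers the ABC text from the inner patches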