Spaces:
Running
Running
MuGeminorum
commited on
Commit
•
27cf0c7
1
Parent(s):
f9ed6a5
add copy btn
Browse files
app.py
CHANGED
@@ -10,24 +10,48 @@ from config import *
|
|
10 |
from convert import *
|
11 |
from transformers import GPT2Config
|
12 |
import warnings
|
13 |
-
|
|
|
14 |
|
15 |
|
16 |
def get_args(parser):
|
17 |
-
parser.add_argument(
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
parser.add_argument(
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
parser.add_argument(
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
args = parser.parse_args()
|
32 |
|
33 |
return args
|
@@ -40,14 +64,14 @@ def generate_abc(args, region):
|
|
40 |
num_hidden_layers=PATCH_NUM_LAYERS,
|
41 |
max_length=PATCH_LENGTH,
|
42 |
max_position_embeddings=PATCH_LENGTH,
|
43 |
-
vocab_size=1
|
44 |
)
|
45 |
|
46 |
char_config = GPT2Config(
|
47 |
num_hidden_layers=CHAR_NUM_LAYERS,
|
48 |
max_length=PATCH_SIZE,
|
49 |
max_position_embeddings=PATCH_SIZE,
|
50 |
-
vocab_size=128
|
51 |
)
|
52 |
|
53 |
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
|
@@ -60,8 +84,8 @@ def generate_abc(args, region):
|
|
60 |
else:
|
61 |
download()
|
62 |
|
63 |
-
checkpoint = torch.load(filename, map_location=torch.device(
|
64 |
-
model.load_state_dict(checkpoint[
|
65 |
model = model.to(device)
|
66 |
model.eval()
|
67 |
|
@@ -76,20 +100,20 @@ def generate_abc(args, region):
|
|
76 |
seed = args.seed
|
77 |
show_control_code = args.show_control_code
|
78 |
|
79 |
-
print(" HYPERPARAMETERS ".center(60, "#"),
|
80 |
args = vars(args)
|
81 |
|
82 |
for key in args.keys():
|
83 |
-
print(f
|
84 |
|
85 |
-
print(
|
86 |
|
87 |
start_time = time.time()
|
88 |
|
89 |
for i in range(num_tunes):
|
90 |
-
title_artist = f
|
91 |
tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
|
92 |
-
lines = re.split(r
|
93 |
tune = ""
|
94 |
skip = False
|
95 |
for line in lines:
|
@@ -104,8 +128,7 @@ def generate_abc(args, region):
|
|
104 |
skip = True
|
105 |
|
106 |
input_patches = torch.tensor(
|
107 |
-
[patchilizer.encode(prompt, add_special_patches=True)[:-1]],
|
108 |
-
device=device
|
109 |
)
|
110 |
|
111 |
if tune == "":
|
@@ -113,10 +136,10 @@ def generate_abc(args, region):
|
|
113 |
|
114 |
else:
|
115 |
prefix = patchilizer.decode(input_patches[0])
|
116 |
-
remaining_tokens = prompt[len(prefix):]
|
117 |
tokens = torch.tensor(
|
118 |
-
[patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens],
|
119 |
-
device=device
|
120 |
)
|
121 |
|
122 |
while input_patches.shape[1] < max_patch:
|
@@ -126,7 +149,7 @@ def generate_abc(args, region):
|
|
126 |
top_p=top_p,
|
127 |
top_k=top_k,
|
128 |
temperature=temperature,
|
129 |
-
seed=seed
|
130 |
)
|
131 |
tokens = None
|
132 |
|
@@ -140,17 +163,15 @@ def generate_abc(args, region):
|
|
140 |
if next_bar == "":
|
141 |
break
|
142 |
|
143 |
-
next_bar = remaining_tokens+next_bar
|
144 |
remaining_tokens = ""
|
145 |
|
146 |
predicted_patch = torch.tensor(
|
147 |
-
patchilizer.bar2patch(next_bar),
|
148 |
-
device=device
|
149 |
).unsqueeze(0)
|
150 |
|
151 |
input_patches = torch.cat(
|
152 |
-
[input_patches, predicted_patch.unsqueeze(0)],
|
153 |
-
dim=1
|
154 |
)
|
155 |
|
156 |
else:
|
@@ -160,11 +181,11 @@ def generate_abc(args, region):
|
|
160 |
print("\n")
|
161 |
|
162 |
print("Generation time: {:.2f} seconds".format(time.time() - start_time))
|
163 |
-
create_dir(
|
164 |
timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
|
165 |
-
out_midi = abc_to_midi(tunes, f
|
166 |
-
out_xml = abc_to_musicxml(tunes, f
|
167 |
-
out_mxl = musicxml_to_mxl(f
|
168 |
pdf_file, jpg_file = mxl2jpg(out_mxl)
|
169 |
wav_file = midi2wav(out_midi)
|
170 |
|
@@ -172,8 +193,8 @@ def generate_abc(args, region):
|
|
172 |
|
173 |
|
174 |
def inference(region):
|
175 |
-
if os.path.exists(
|
176 |
-
shutil.rmtree(
|
177 |
|
178 |
parser = argparse.ArgumentParser()
|
179 |
args = get_args(parser)
|
@@ -184,30 +205,33 @@ with gr.Blocks() as demo:
|
|
184 |
with gr.Row():
|
185 |
with gr.Column():
|
186 |
region_opt = gr.Dropdown(
|
187 |
-
choices=[
|
188 |
-
|
189 |
-
|
190 |
-
value='Mondstadt',
|
191 |
-
label='Region genre'
|
192 |
)
|
193 |
gen_btn = gr.Button("Generate")
|
194 |
|
195 |
with gr.Column():
|
196 |
-
wav_output = gr.Audio(label=
|
197 |
dld_midi = gr.components.File(label="Download MIDI")
|
198 |
pdf_score = gr.components.File(label="Download PDF score")
|
199 |
dld_xml = gr.components.File(label="Download MusicXML")
|
200 |
dld_mxl = gr.components.File(label="Download MXL")
|
201 |
-
abc_output = gr.
|
202 |
-
img_score = gr.Image(label=
|
203 |
|
204 |
gen_btn.click(
|
205 |
inference,
|
206 |
inputs=region_opt,
|
207 |
outputs=[
|
208 |
-
abc_output,
|
209 |
-
|
210 |
-
|
|
|
|
|
|
|
|
|
|
|
211 |
)
|
212 |
|
213 |
demo.launch(share=True)
|
|
|
10 |
from convert import *
|
11 |
from transformers import GPT2Config
|
12 |
import warnings
|
13 |
+
|
14 |
+
warnings.filterwarnings("ignore")
|
15 |
|
16 |
|
17 |
def get_args(parser):
|
18 |
+
parser.add_argument(
|
19 |
+
"-num_tunes",
|
20 |
+
type=int,
|
21 |
+
default=1,
|
22 |
+
help="the number of independently computed returned tunes",
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"-max_patch",
|
26 |
+
type=int,
|
27 |
+
default=128,
|
28 |
+
help="integer to define the maximum length in tokens of each tune",
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"-top_p",
|
32 |
+
type=float,
|
33 |
+
default=0.8,
|
34 |
+
help="float to define the tokens that are within the sample operation of text generation",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"-top_k",
|
38 |
+
type=int,
|
39 |
+
default=8,
|
40 |
+
help="integer to define the tokens that are within the sample operation of text generation",
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"-temperature",
|
44 |
+
type=float,
|
45 |
+
default=1.2,
|
46 |
+
help="the temperature of the sampling operation",
|
47 |
+
)
|
48 |
+
parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
|
49 |
+
parser.add_argument(
|
50 |
+
"-show_control_code",
|
51 |
+
type=bool,
|
52 |
+
default=True,
|
53 |
+
help="whether to show control code",
|
54 |
+
)
|
55 |
args = parser.parse_args()
|
56 |
|
57 |
return args
|
|
|
64 |
num_hidden_layers=PATCH_NUM_LAYERS,
|
65 |
max_length=PATCH_LENGTH,
|
66 |
max_position_embeddings=PATCH_LENGTH,
|
67 |
+
vocab_size=1,
|
68 |
)
|
69 |
|
70 |
char_config = GPT2Config(
|
71 |
num_hidden_layers=CHAR_NUM_LAYERS,
|
72 |
max_length=PATCH_SIZE,
|
73 |
max_position_embeddings=PATCH_SIZE,
|
74 |
+
vocab_size=128,
|
75 |
)
|
76 |
|
77 |
model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
|
|
|
84 |
else:
|
85 |
download()
|
86 |
|
87 |
+
checkpoint = torch.load(filename, map_location=torch.device("cpu"))
|
88 |
+
model.load_state_dict(checkpoint["model"])
|
89 |
model = model.to(device)
|
90 |
model.eval()
|
91 |
|
|
|
100 |
seed = args.seed
|
101 |
show_control_code = args.show_control_code
|
102 |
|
103 |
+
print(" HYPERPARAMETERS ".center(60, "#"), "\n")
|
104 |
args = vars(args)
|
105 |
|
106 |
for key in args.keys():
|
107 |
+
print(f"{key}: {str(args[key])}")
|
108 |
|
109 |
+
print("\n", " OUTPUT TUNES ".center(60, "#"))
|
110 |
|
111 |
start_time = time.time()
|
112 |
|
113 |
for i in range(num_tunes):
|
114 |
+
title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
|
115 |
tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
|
116 |
+
lines = re.split(r"(\n)", tune)
|
117 |
tune = ""
|
118 |
skip = False
|
119 |
for line in lines:
|
|
|
128 |
skip = True
|
129 |
|
130 |
input_patches = torch.tensor(
|
131 |
+
[patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device
|
|
|
132 |
)
|
133 |
|
134 |
if tune == "":
|
|
|
136 |
|
137 |
else:
|
138 |
prefix = patchilizer.decode(input_patches[0])
|
139 |
+
remaining_tokens = prompt[len(prefix) :]
|
140 |
tokens = torch.tensor(
|
141 |
+
[patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
|
142 |
+
device=device,
|
143 |
)
|
144 |
|
145 |
while input_patches.shape[1] < max_patch:
|
|
|
149 |
top_p=top_p,
|
150 |
top_k=top_k,
|
151 |
temperature=temperature,
|
152 |
+
seed=seed,
|
153 |
)
|
154 |
tokens = None
|
155 |
|
|
|
163 |
if next_bar == "":
|
164 |
break
|
165 |
|
166 |
+
next_bar = remaining_tokens + next_bar
|
167 |
remaining_tokens = ""
|
168 |
|
169 |
predicted_patch = torch.tensor(
|
170 |
+
patchilizer.bar2patch(next_bar), device=device
|
|
|
171 |
).unsqueeze(0)
|
172 |
|
173 |
input_patches = torch.cat(
|
174 |
+
[input_patches, predicted_patch.unsqueeze(0)], dim=1
|
|
|
175 |
)
|
176 |
|
177 |
else:
|
|
|
181 |
print("\n")
|
182 |
|
183 |
print("Generation time: {:.2f} seconds".format(time.time() - start_time))
|
184 |
+
create_dir("./tmp")
|
185 |
timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
|
186 |
+
out_midi = abc_to_midi(tunes, f"./tmp/[{region}]{timestamp}.mid")
|
187 |
+
out_xml = abc_to_musicxml(tunes, f"./tmp/[{region}]{timestamp}.musicxml")
|
188 |
+
out_mxl = musicxml_to_mxl(f"./tmp/[{region}]{timestamp}.musicxml")
|
189 |
pdf_file, jpg_file = mxl2jpg(out_mxl)
|
190 |
wav_file = midi2wav(out_midi)
|
191 |
|
|
|
193 |
|
194 |
|
195 |
def inference(region):
|
196 |
+
if os.path.exists("./tmp"):
|
197 |
+
shutil.rmtree("./tmp")
|
198 |
|
199 |
parser = argparse.ArgumentParser()
|
200 |
args = get_args(parser)
|
|
|
205 |
with gr.Row():
|
206 |
with gr.Column():
|
207 |
region_opt = gr.Dropdown(
|
208 |
+
choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
|
209 |
+
value="Mondstadt",
|
210 |
+
label="Region genre",
|
|
|
|
|
211 |
)
|
212 |
gen_btn = gr.Button("Generate")
|
213 |
|
214 |
with gr.Column():
|
215 |
+
wav_output = gr.Audio(label="Audio", type="filepath")
|
216 |
dld_midi = gr.components.File(label="Download MIDI")
|
217 |
pdf_score = gr.components.File(label="Download PDF score")
|
218 |
dld_xml = gr.components.File(label="Download MusicXML")
|
219 |
dld_mxl = gr.components.File(label="Download MXL")
|
220 |
+
abc_output = gr.Textbox(label="abc score", show_copy_button=True)
|
221 |
+
img_score = gr.Image(label="Staff", type="filepath")
|
222 |
|
223 |
gen_btn.click(
|
224 |
inference,
|
225 |
inputs=region_opt,
|
226 |
outputs=[
|
227 |
+
abc_output,
|
228 |
+
dld_midi,
|
229 |
+
pdf_score,
|
230 |
+
dld_xml,
|
231 |
+
dld_mxl,
|
232 |
+
img_score,
|
233 |
+
wav_output,
|
234 |
+
],
|
235 |
)
|
236 |
|
237 |
demo.launch(share=True)
|
utils.py
CHANGED
@@ -35,15 +35,16 @@ def create_dir(dir_path):
|
|
35 |
def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
|
36 |
import time
|
37 |
import requests
|
|
|
38 |
try:
|
39 |
response = requests.get(url, stream=True)
|
40 |
-
total_size = int(response.headers.get(
|
41 |
chunk_size = 1024
|
42 |
|
43 |
-
with open(filename,
|
44 |
desc=f"Downloading weights to '{filename}'...",
|
45 |
total=total_size,
|
46 |
-
unit=
|
47 |
unit_scale=True,
|
48 |
unit_divisor=1024,
|
49 |
) as bar:
|
@@ -51,7 +52,7 @@ def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
|
|
51 |
size = file.write(data)
|
52 |
bar.update(size)
|
53 |
|
54 |
-
except
|
55 |
print(f"Error: {e}")
|
56 |
time.sleep(3)
|
57 |
download(filename, ZH_WEIGHT_URL)
|
@@ -59,7 +60,7 @@ def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
|
|
59 |
|
60 |
class Patchilizer:
|
61 |
"""
|
62 |
-
A class for converting music bars to patches and vice versa.
|
63 |
"""
|
64 |
|
65 |
def __init__(self):
|
@@ -73,7 +74,7 @@ class Patchilizer:
|
|
73 |
"""
|
74 |
Split a body of music into individual bars.
|
75 |
"""
|
76 |
-
bars = re.split(self.regexPattern,
|
77 |
bars = list(filter(None, bars))
|
78 |
# remove empty strings
|
79 |
if bars[0] in self.delimiters:
|
@@ -87,8 +88,7 @@ class Patchilizer:
|
|
87 |
"""
|
88 |
Convert a bar into a patch of specified length.
|
89 |
"""
|
90 |
-
patch = [self.bos_token_id] +
|
91 |
-
[ord(c) for c in bar] + [self.eos_token_id]
|
92 |
patch = patch[:patch_size]
|
93 |
patch += [self.pad_token_id] * (patch_size - len(patch))
|
94 |
return patch
|
@@ -97,31 +97,46 @@ class Patchilizer:
|
|
97 |
"""
|
98 |
Convert a patch into a bar.
|
99 |
"""
|
100 |
-
return
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
def encode(
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
"""
|
104 |
Encode music into patches of specified length.
|
105 |
"""
|
106 |
-
lines = unidecode(abc_code).split(
|
107 |
lines = list(filter(None, lines)) # remove empty lines
|
108 |
|
109 |
body = ""
|
110 |
patches = []
|
111 |
|
112 |
for line in lines:
|
113 |
-
if len(line) > 1 and (
|
|
|
|
|
114 |
if body:
|
115 |
bars = self.split_bars(body)
|
116 |
patches.extend(
|
117 |
-
self.bar2patch(
|
|
|
|
|
|
|
118 |
)
|
119 |
body = ""
|
120 |
|
121 |
-
patches.append(self.bar2patch(line +
|
122 |
|
123 |
else:
|
124 |
-
body += line +
|
125 |
|
126 |
if body:
|
127 |
patches.extend(
|
@@ -129,10 +144,8 @@ class Patchilizer:
|
|
129 |
)
|
130 |
|
131 |
if add_special_patches:
|
132 |
-
bos_patch = [self.bos_token_id] *
|
133 |
-
|
134 |
-
eos_patch = [self.bos_token_id] + \
|
135 |
-
[self.eos_token_id] * (patch_size-1)
|
136 |
patches = [bos_patch] + patches + [eos_patch]
|
137 |
|
138 |
return patches[:patch_length]
|
@@ -141,12 +154,12 @@ class Patchilizer:
|
|
141 |
"""
|
142 |
Decode patches into music.
|
143 |
"""
|
144 |
-
return
|
145 |
|
146 |
|
147 |
class PatchLevelDecoder(PreTrainedModel):
|
148 |
"""
|
149 |
-
An Patch-level Decoder model for generating patch features in an auto-regressive manner.
|
150 |
It inherits PreTrainedModel from transformers.
|
151 |
"""
|
152 |
|
@@ -171,7 +184,7 @@ class PatchLevelDecoder(PreTrainedModel):
|
|
171 |
|
172 |
class CharLevelDecoder(PreTrainedModel):
|
173 |
"""
|
174 |
-
A Char-level Decoder model for generating the characters within each bar patch sequentially.
|
175 |
It inherits PreTrainedModel from transformers.
|
176 |
"""
|
177 |
|
@@ -182,7 +195,12 @@ class CharLevelDecoder(PreTrainedModel):
|
|
182 |
self.eos_token_id = 2
|
183 |
self.base = GPT2LMHeadModel(config)
|
184 |
|
185 |
-
def forward(
|
|
|
|
|
|
|
|
|
|
|
186 |
"""
|
187 |
The forward pass of the char-level decoder model.
|
188 |
:param encoded_patches: the encoded patches
|
@@ -198,7 +216,10 @@ class CharLevelDecoder(PreTrainedModel):
|
|
198 |
target_masks = target_masks.masked_fill_(labels == -100, 0)
|
199 |
|
200 |
# select patches
|
201 |
-
if
|
|
|
|
|
|
|
202 |
indices = list(range(len(target_patches)))
|
203 |
random.shuffle(indices)
|
204 |
selected_indices = sorted(indices[:patch_sampling_batch_size])
|
@@ -210,20 +231,16 @@ class CharLevelDecoder(PreTrainedModel):
|
|
210 |
|
211 |
# get input embeddings
|
212 |
inputs_embeds = torch.nn.functional.embedding(
|
213 |
-
target_patches,
|
214 |
-
self.base.transformer.wte.weight
|
215 |
)
|
216 |
|
217 |
# concatenate the encoded patches with the input embeddings
|
218 |
inputs_embeds = torch.cat(
|
219 |
-
(encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]),
|
220 |
-
dim=1
|
221 |
)
|
222 |
|
223 |
return self.base(
|
224 |
-
inputs_embeds=inputs_embeds,
|
225 |
-
attention_mask=target_masks,
|
226 |
-
labels=labels
|
227 |
)
|
228 |
|
229 |
def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
|
@@ -237,10 +254,7 @@ class CharLevelDecoder(PreTrainedModel):
|
|
237 |
tokens = tokens.reshape(1, -1)
|
238 |
|
239 |
# Get input embeddings
|
240 |
-
tokens = torch.nn.functional.embedding(
|
241 |
-
tokens,
|
242 |
-
self.base.transformer.wte.weight
|
243 |
-
)
|
244 |
|
245 |
# Concatenate the encoded patch with the input embeddings
|
246 |
tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
|
@@ -249,17 +263,14 @@ class CharLevelDecoder(PreTrainedModel):
|
|
249 |
outputs = self.base(inputs_embeds=tokens)
|
250 |
|
251 |
# Get probabilities of next token
|
252 |
-
probs = torch.nn.functional.softmax(
|
253 |
-
outputs.logits.squeeze(0)[-1],
|
254 |
-
dim=-1
|
255 |
-
)
|
256 |
|
257 |
return probs
|
258 |
|
259 |
|
260 |
class TunesFormer(PreTrainedModel):
|
261 |
"""
|
262 |
-
TunesFormer is a hierarchical music generation model based on bar patching.
|
263 |
It includes a patch-level decoder and a character-level decoder.
|
264 |
It inherits PreTrainedModel from transformers.
|
265 |
"""
|
@@ -271,18 +282,14 @@ class TunesFormer(PreTrainedModel):
|
|
271 |
self.eos_token_id = 2
|
272 |
if share_weights:
|
273 |
max_layers = max(
|
274 |
-
encoder_config.num_hidden_layers,
|
275 |
-
decoder_config.num_hidden_layers
|
276 |
)
|
277 |
|
278 |
-
max_context_size = max(
|
279 |
-
encoder_config.max_length,
|
280 |
-
decoder_config.max_length
|
281 |
-
)
|
282 |
|
283 |
max_position_embeddings = max(
|
284 |
encoder_config.max_position_embeddings,
|
285 |
-
decoder_config.max_position_embeddings
|
286 |
)
|
287 |
|
288 |
encoder_config.num_hidden_layers = max_layers
|
@@ -298,17 +305,24 @@ class TunesFormer(PreTrainedModel):
|
|
298 |
if share_weights:
|
299 |
self.patch_level_decoder.base = self.char_level_decoder.base.transformer
|
300 |
|
301 |
-
def forward(
|
|
|
|
|
|
|
|
|
302 |
"""
|
303 |
The forward pass of the TunesFormer model.
|
304 |
:param patches: the patches to be both encoded and decoded
|
305 |
:return: the decoded patches
|
306 |
"""
|
307 |
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
|
308 |
-
encoded_patches = self.patch_level_decoder(
|
309 |
-
patches)["last_hidden_state"]
|
310 |
|
311 |
-
return self.char_level_decoder(
|
|
|
|
|
|
|
|
|
312 |
|
313 |
def generate(
|
314 |
self,
|
@@ -317,7 +331,7 @@ class TunesFormer(PreTrainedModel):
|
|
317 |
top_p: float = 1,
|
318 |
top_k: int = 0,
|
319 |
temperature: float = 1,
|
320 |
-
seed: int = None
|
321 |
):
|
322 |
"""
|
323 |
The generate function for generating patches based on patches.
|
@@ -325,8 +339,7 @@ class TunesFormer(PreTrainedModel):
|
|
325 |
:return: the generated patches
|
326 |
"""
|
327 |
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
|
328 |
-
encoded_patches = self.patch_level_decoder(
|
329 |
-
patches)["last_hidden_state"]
|
330 |
|
331 |
if tokens == None:
|
332 |
tokens = torch.tensor([self.bos_token_id], device=self.device)
|
@@ -342,19 +355,17 @@ class TunesFormer(PreTrainedModel):
|
|
342 |
else:
|
343 |
n_seed = None
|
344 |
|
345 |
-
prob =
|
346 |
-
encoded_patches[0][-1],
|
347 |
-
|
348 |
-
|
|
|
|
|
349 |
|
350 |
prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
|
351 |
prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
|
352 |
|
353 |
-
token = temperature_sampling(
|
354 |
-
prob,
|
355 |
-
temperature=temperature,
|
356 |
-
seed=n_seed
|
357 |
-
)
|
358 |
|
359 |
generated_patch.append(token)
|
360 |
if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
|
@@ -362,8 +373,7 @@ class TunesFormer(PreTrainedModel):
|
|
362 |
|
363 |
else:
|
364 |
tokens = torch.cat(
|
365 |
-
(tokens, torch.tensor([token], device=self.device)),
|
366 |
-
dim=0
|
367 |
)
|
368 |
|
369 |
return generated_patch, n_seed
|
@@ -374,8 +384,9 @@ class PatchilizedData(Dataset):
|
|
374 |
self.texts = []
|
375 |
|
376 |
for item in tqdm(items):
|
377 |
-
text = item[
|
378 |
-
|
|
|
379 |
input_patch = patchilizer.encode(text, add_special_patches=True)
|
380 |
input_patch = torch.tensor(input_patch)
|
381 |
if torch.sum(input_patch) != 0:
|
|
|
35 |
def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
|
36 |
import time
|
37 |
import requests
|
38 |
+
|
39 |
try:
|
40 |
response = requests.get(url, stream=True)
|
41 |
+
total_size = int(response.headers.get("content-length", 0))
|
42 |
chunk_size = 1024
|
43 |
|
44 |
+
with open(filename, "wb") as file, tqdm(
|
45 |
desc=f"Downloading weights to '{filename}'...",
|
46 |
total=total_size,
|
47 |
+
unit="B",
|
48 |
unit_scale=True,
|
49 |
unit_divisor=1024,
|
50 |
) as bar:
|
|
|
52 |
size = file.write(data)
|
53 |
bar.update(size)
|
54 |
|
55 |
+
except Exception as e:
|
56 |
print(f"Error: {e}")
|
57 |
time.sleep(3)
|
58 |
download(filename, ZH_WEIGHT_URL)
|
|
|
60 |
|
61 |
class Patchilizer:
|
62 |
"""
|
63 |
+
A class for converting music bars to patches and vice versa.
|
64 |
"""
|
65 |
|
66 |
def __init__(self):
|
|
|
74 |
"""
|
75 |
Split a body of music into individual bars.
|
76 |
"""
|
77 |
+
bars = re.split(self.regexPattern, "".join(body))
|
78 |
bars = list(filter(None, bars))
|
79 |
# remove empty strings
|
80 |
if bars[0] in self.delimiters:
|
|
|
88 |
"""
|
89 |
Convert a bar into a patch of specified length.
|
90 |
"""
|
91 |
+
patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
|
|
|
92 |
patch = patch[:patch_size]
|
93 |
patch += [self.pad_token_id] * (patch_size - len(patch))
|
94 |
return patch
|
|
|
97 |
"""
|
98 |
Convert a patch into a bar.
|
99 |
"""
|
100 |
+
return "".join(
|
101 |
+
chr(idx) if idx > self.eos_token_id else ""
|
102 |
+
for idx in patch
|
103 |
+
if idx != self.eos_token_id
|
104 |
+
)
|
105 |
|
106 |
+
def encode(
|
107 |
+
self,
|
108 |
+
abc_code,
|
109 |
+
patch_length=PATCH_LENGTH,
|
110 |
+
patch_size=PATCH_SIZE,
|
111 |
+
add_special_patches=False,
|
112 |
+
):
|
113 |
"""
|
114 |
Encode music into patches of specified length.
|
115 |
"""
|
116 |
+
lines = unidecode(abc_code).split("\n")
|
117 |
lines = list(filter(None, lines)) # remove empty lines
|
118 |
|
119 |
body = ""
|
120 |
patches = []
|
121 |
|
122 |
for line in lines:
|
123 |
+
if len(line) > 1 and (
|
124 |
+
(line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
|
125 |
+
):
|
126 |
if body:
|
127 |
bars = self.split_bars(body)
|
128 |
patches.extend(
|
129 |
+
self.bar2patch(
|
130 |
+
bar + "\n" if idx == len(bars) - 1 else bar, patch_size
|
131 |
+
)
|
132 |
+
for idx, bar in enumerate(bars)
|
133 |
)
|
134 |
body = ""
|
135 |
|
136 |
+
patches.append(self.bar2patch(line + "\n", patch_size))
|
137 |
|
138 |
else:
|
139 |
+
body += line + "\n"
|
140 |
|
141 |
if body:
|
142 |
patches.extend(
|
|
|
144 |
)
|
145 |
|
146 |
if add_special_patches:
|
147 |
+
bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
|
148 |
+
eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
|
|
|
|
|
149 |
patches = [bos_patch] + patches + [eos_patch]
|
150 |
|
151 |
return patches[:patch_length]
|
|
|
154 |
"""
|
155 |
Decode patches into music.
|
156 |
"""
|
157 |
+
return "".join(self.patch2bar(patch) for patch in patches)
|
158 |
|
159 |
|
160 |
class PatchLevelDecoder(PreTrainedModel):
|
161 |
"""
|
162 |
+
An Patch-level Decoder model for generating patch features in an auto-regressive manner.
|
163 |
It inherits PreTrainedModel from transformers.
|
164 |
"""
|
165 |
|
|
|
184 |
|
185 |
class CharLevelDecoder(PreTrainedModel):
|
186 |
"""
|
187 |
+
A Char-level Decoder model for generating the characters within each bar patch sequentially.
|
188 |
It inherits PreTrainedModel from transformers.
|
189 |
"""
|
190 |
|
|
|
195 |
self.eos_token_id = 2
|
196 |
self.base = GPT2LMHeadModel(config)
|
197 |
|
198 |
+
def forward(
|
199 |
+
self,
|
200 |
+
encoded_patches: torch.Tensor,
|
201 |
+
target_patches: torch.Tensor,
|
202 |
+
patch_sampling_batch_size: int,
|
203 |
+
):
|
204 |
"""
|
205 |
The forward pass of the char-level decoder model.
|
206 |
:param encoded_patches: the encoded patches
|
|
|
216 |
target_masks = target_masks.masked_fill_(labels == -100, 0)
|
217 |
|
218 |
# select patches
|
219 |
+
if (
|
220 |
+
patch_sampling_batch_size != 0
|
221 |
+
and patch_sampling_batch_size < target_patches.shape[0]
|
222 |
+
):
|
223 |
indices = list(range(len(target_patches)))
|
224 |
random.shuffle(indices)
|
225 |
selected_indices = sorted(indices[:patch_sampling_batch_size])
|
|
|
231 |
|
232 |
# get input embeddings
|
233 |
inputs_embeds = torch.nn.functional.embedding(
|
234 |
+
target_patches, self.base.transformer.wte.weight
|
|
|
235 |
)
|
236 |
|
237 |
# concatenate the encoded patches with the input embeddings
|
238 |
inputs_embeds = torch.cat(
|
239 |
+
(encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
|
|
|
240 |
)
|
241 |
|
242 |
return self.base(
|
243 |
+
inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
|
|
|
|
|
244 |
)
|
245 |
|
246 |
def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
|
|
|
254 |
tokens = tokens.reshape(1, -1)
|
255 |
|
256 |
# Get input embeddings
|
257 |
+
tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
|
|
|
|
|
|
|
258 |
|
259 |
# Concatenate the encoded patch with the input embeddings
|
260 |
tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
|
|
|
263 |
outputs = self.base(inputs_embeds=tokens)
|
264 |
|
265 |
# Get probabilities of next token
|
266 |
+
probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
|
|
|
|
|
|
|
267 |
|
268 |
return probs
|
269 |
|
270 |
|
271 |
class TunesFormer(PreTrainedModel):
|
272 |
"""
|
273 |
+
TunesFormer is a hierarchical music generation model based on bar patching.
|
274 |
It includes a patch-level decoder and a character-level decoder.
|
275 |
It inherits PreTrainedModel from transformers.
|
276 |
"""
|
|
|
282 |
self.eos_token_id = 2
|
283 |
if share_weights:
|
284 |
max_layers = max(
|
285 |
+
encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
|
|
|
286 |
)
|
287 |
|
288 |
+
max_context_size = max(encoder_config.max_length, decoder_config.max_length)
|
|
|
|
|
|
|
289 |
|
290 |
max_position_embeddings = max(
|
291 |
encoder_config.max_position_embeddings,
|
292 |
+
decoder_config.max_position_embeddings,
|
293 |
)
|
294 |
|
295 |
encoder_config.num_hidden_layers = max_layers
|
|
|
305 |
if share_weights:
|
306 |
self.patch_level_decoder.base = self.char_level_decoder.base.transformer
|
307 |
|
308 |
+
def forward(
|
309 |
+
self,
|
310 |
+
patches: torch.Tensor,
|
311 |
+
patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
|
312 |
+
):
|
313 |
"""
|
314 |
The forward pass of the TunesFormer model.
|
315 |
:param patches: the patches to be both encoded and decoded
|
316 |
:return: the decoded patches
|
317 |
"""
|
318 |
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
|
319 |
+
encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
|
|
|
320 |
|
321 |
+
return self.char_level_decoder(
|
322 |
+
encoded_patches.squeeze(0)[:-1, :],
|
323 |
+
patches.squeeze(0)[1:, :],
|
324 |
+
patch_sampling_batch_size,
|
325 |
+
)
|
326 |
|
327 |
def generate(
|
328 |
self,
|
|
|
331 |
top_p: float = 1,
|
332 |
top_k: int = 0,
|
333 |
temperature: float = 1,
|
334 |
+
seed: int = None,
|
335 |
):
|
336 |
"""
|
337 |
The generate function for generating patches based on patches.
|
|
|
339 |
:return: the generated patches
|
340 |
"""
|
341 |
patches = patches.reshape(len(patches), -1, PATCH_SIZE)
|
342 |
+
encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
|
|
|
343 |
|
344 |
if tokens == None:
|
345 |
tokens = torch.tensor([self.bos_token_id], device=self.device)
|
|
|
355 |
else:
|
356 |
n_seed = None
|
357 |
|
358 |
+
prob = (
|
359 |
+
self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
|
360 |
+
.cpu()
|
361 |
+
.detach()
|
362 |
+
.numpy()
|
363 |
+
)
|
364 |
|
365 |
prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
|
366 |
prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
|
367 |
|
368 |
+
token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
|
|
|
|
|
|
|
|
|
369 |
|
370 |
generated_patch.append(token)
|
371 |
if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
|
|
|
373 |
|
374 |
else:
|
375 |
tokens = torch.cat(
|
376 |
+
(tokens, torch.tensor([token], device=self.device)), dim=0
|
|
|
377 |
)
|
378 |
|
379 |
return generated_patch, n_seed
|
|
|
384 |
self.texts = []
|
385 |
|
386 |
for item in tqdm(items):
|
387 |
+
text = item["control code"] + "\n".join(
|
388 |
+
item["abc notation"].split("\n")[1:]
|
389 |
+
)
|
390 |
input_patch = patchilizer.encode(text, add_special_patches=True)
|
391 |
input_patch = torch.tensor(input_patch)
|
392 |
if torch.sum(input_patch) != 0:
|