MuGeminorum commited on
Commit
27cf0c7
1 Parent(s): f9ed6a5

add copy btn

Browse files
Files changed (2) hide show
  1. app.py +76 -52
  2. utils.py +78 -67
app.py CHANGED
@@ -10,24 +10,48 @@ from config import *
10
  from convert import *
11
  from transformers import GPT2Config
12
  import warnings
13
- warnings.filterwarnings('ignore')
 
14
 
15
 
16
  def get_args(parser):
17
- parser.add_argument('-num_tunes', type=int, default=1,
18
- help='the number of independently computed returned tunes')
19
- parser.add_argument('-max_patch', type=int, default=128,
20
- help='integer to define the maximum length in tokens of each tune')
21
- parser.add_argument('-top_p', type=float, default=0.8,
22
- help='float to define the tokens that are within the sample operation of text generation')
23
- parser.add_argument('-top_k', type=int, default=8,
24
- help='integer to define the tokens that are within the sample operation of text generation')
25
- parser.add_argument('-temperature', type=float, default=1.2,
26
- help='the temperature of the sampling operation')
27
- parser.add_argument('-seed', type=int, default=None,
28
- help='seed for randomstate')
29
- parser.add_argument('-show_control_code', type=bool,
30
- default=True, help='whether to show control code')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  args = parser.parse_args()
32
 
33
  return args
@@ -40,14 +64,14 @@ def generate_abc(args, region):
40
  num_hidden_layers=PATCH_NUM_LAYERS,
41
  max_length=PATCH_LENGTH,
42
  max_position_embeddings=PATCH_LENGTH,
43
- vocab_size=1
44
  )
45
 
46
  char_config = GPT2Config(
47
  num_hidden_layers=CHAR_NUM_LAYERS,
48
  max_length=PATCH_SIZE,
49
  max_position_embeddings=PATCH_SIZE,
50
- vocab_size=128
51
  )
52
 
53
  model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
@@ -60,8 +84,8 @@ def generate_abc(args, region):
60
  else:
61
  download()
62
 
63
- checkpoint = torch.load(filename, map_location=torch.device('cpu'))
64
- model.load_state_dict(checkpoint['model'])
65
  model = model.to(device)
66
  model.eval()
67
 
@@ -76,20 +100,20 @@ def generate_abc(args, region):
76
  seed = args.seed
77
  show_control_code = args.show_control_code
78
 
79
- print(" HYPERPARAMETERS ".center(60, "#"), '\n')
80
  args = vars(args)
81
 
82
  for key in args.keys():
83
- print(f'{key}: {str(args[key])}')
84
 
85
- print('\n', " OUTPUT TUNES ".center(60, "#"))
86
 
87
  start_time = time.time()
88
 
89
  for i in range(num_tunes):
90
- title_artist = f'T:{region} Fragment\nC:Generated by AI\n'
91
  tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
92
- lines = re.split(r'(\n)', tune)
93
  tune = ""
94
  skip = False
95
  for line in lines:
@@ -104,8 +128,7 @@ def generate_abc(args, region):
104
  skip = True
105
 
106
  input_patches = torch.tensor(
107
- [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
108
- device=device
109
  )
110
 
111
  if tune == "":
@@ -113,10 +136,10 @@ def generate_abc(args, region):
113
 
114
  else:
115
  prefix = patchilizer.decode(input_patches[0])
116
- remaining_tokens = prompt[len(prefix):]
117
  tokens = torch.tensor(
118
- [patchilizer.bos_token_id]+[ord(c) for c in remaining_tokens],
119
- device=device
120
  )
121
 
122
  while input_patches.shape[1] < max_patch:
@@ -126,7 +149,7 @@ def generate_abc(args, region):
126
  top_p=top_p,
127
  top_k=top_k,
128
  temperature=temperature,
129
- seed=seed
130
  )
131
  tokens = None
132
 
@@ -140,17 +163,15 @@ def generate_abc(args, region):
140
  if next_bar == "":
141
  break
142
 
143
- next_bar = remaining_tokens+next_bar
144
  remaining_tokens = ""
145
 
146
  predicted_patch = torch.tensor(
147
- patchilizer.bar2patch(next_bar),
148
- device=device
149
  ).unsqueeze(0)
150
 
151
  input_patches = torch.cat(
152
- [input_patches, predicted_patch.unsqueeze(0)],
153
- dim=1
154
  )
155
 
156
  else:
@@ -160,11 +181,11 @@ def generate_abc(args, region):
160
  print("\n")
161
 
162
  print("Generation time: {:.2f} seconds".format(time.time() - start_time))
163
- create_dir('./tmp')
164
  timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
165
- out_midi = abc_to_midi(tunes, f'./tmp/[{region}]{timestamp}.mid')
166
- out_xml = abc_to_musicxml(tunes, f'./tmp/[{region}]{timestamp}.musicxml')
167
- out_mxl = musicxml_to_mxl(f'./tmp/[{region}]{timestamp}.musicxml')
168
  pdf_file, jpg_file = mxl2jpg(out_mxl)
169
  wav_file = midi2wav(out_midi)
170
 
@@ -172,8 +193,8 @@ def generate_abc(args, region):
172
 
173
 
174
  def inference(region):
175
- if os.path.exists('./tmp'):
176
- shutil.rmtree('./tmp')
177
 
178
  parser = argparse.ArgumentParser()
179
  args = get_args(parser)
@@ -184,30 +205,33 @@ with gr.Blocks() as demo:
184
  with gr.Row():
185
  with gr.Column():
186
  region_opt = gr.Dropdown(
187
- choices=[
188
- 'Mondstadt', 'Liyue', 'Inazuma', 'Sumeru', 'Fontaine'
189
- ],
190
- value='Mondstadt',
191
- label='Region genre'
192
  )
193
  gen_btn = gr.Button("Generate")
194
 
195
  with gr.Column():
196
- wav_output = gr.Audio(label='Audio', type='filepath')
197
  dld_midi = gr.components.File(label="Download MIDI")
198
  pdf_score = gr.components.File(label="Download PDF score")
199
  dld_xml = gr.components.File(label="Download MusicXML")
200
  dld_mxl = gr.components.File(label="Download MXL")
201
- abc_output = gr.TextArea(label='abc score')
202
- img_score = gr.Image(label='Staff', type='filepath')
203
 
204
  gen_btn.click(
205
  inference,
206
  inputs=region_opt,
207
  outputs=[
208
- abc_output, dld_midi, pdf_score,
209
- dld_xml, dld_mxl, img_score, wav_output
210
- ]
 
 
 
 
 
211
  )
212
 
213
  demo.launch(share=True)
 
10
  from convert import *
11
  from transformers import GPT2Config
12
  import warnings
13
+
14
+ warnings.filterwarnings("ignore")
15
 
16
 
17
  def get_args(parser):
18
+ parser.add_argument(
19
+ "-num_tunes",
20
+ type=int,
21
+ default=1,
22
+ help="the number of independently computed returned tunes",
23
+ )
24
+ parser.add_argument(
25
+ "-max_patch",
26
+ type=int,
27
+ default=128,
28
+ help="integer to define the maximum length in tokens of each tune",
29
+ )
30
+ parser.add_argument(
31
+ "-top_p",
32
+ type=float,
33
+ default=0.8,
34
+ help="float to define the tokens that are within the sample operation of text generation",
35
+ )
36
+ parser.add_argument(
37
+ "-top_k",
38
+ type=int,
39
+ default=8,
40
+ help="integer to define the tokens that are within the sample operation of text generation",
41
+ )
42
+ parser.add_argument(
43
+ "-temperature",
44
+ type=float,
45
+ default=1.2,
46
+ help="the temperature of the sampling operation",
47
+ )
48
+ parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
49
+ parser.add_argument(
50
+ "-show_control_code",
51
+ type=bool,
52
+ default=True,
53
+ help="whether to show control code",
54
+ )
55
  args = parser.parse_args()
56
 
57
  return args
 
64
  num_hidden_layers=PATCH_NUM_LAYERS,
65
  max_length=PATCH_LENGTH,
66
  max_position_embeddings=PATCH_LENGTH,
67
+ vocab_size=1,
68
  )
69
 
70
  char_config = GPT2Config(
71
  num_hidden_layers=CHAR_NUM_LAYERS,
72
  max_length=PATCH_SIZE,
73
  max_position_embeddings=PATCH_SIZE,
74
+ vocab_size=128,
75
  )
76
 
77
  model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
 
84
  else:
85
  download()
86
 
87
+ checkpoint = torch.load(filename, map_location=torch.device("cpu"))
88
+ model.load_state_dict(checkpoint["model"])
89
  model = model.to(device)
90
  model.eval()
91
 
 
100
  seed = args.seed
101
  show_control_code = args.show_control_code
102
 
103
+ print(" HYPERPARAMETERS ".center(60, "#"), "\n")
104
  args = vars(args)
105
 
106
  for key in args.keys():
107
+ print(f"{key}: {str(args[key])}")
108
 
109
+ print("\n", " OUTPUT TUNES ".center(60, "#"))
110
 
111
  start_time = time.time()
112
 
113
  for i in range(num_tunes):
114
+ title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
115
  tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
116
+ lines = re.split(r"(\n)", tune)
117
  tune = ""
118
  skip = False
119
  for line in lines:
 
128
  skip = True
129
 
130
  input_patches = torch.tensor(
131
+ [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device
 
132
  )
133
 
134
  if tune == "":
 
136
 
137
  else:
138
  prefix = patchilizer.decode(input_patches[0])
139
+ remaining_tokens = prompt[len(prefix) :]
140
  tokens = torch.tensor(
141
+ [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
142
+ device=device,
143
  )
144
 
145
  while input_patches.shape[1] < max_patch:
 
149
  top_p=top_p,
150
  top_k=top_k,
151
  temperature=temperature,
152
+ seed=seed,
153
  )
154
  tokens = None
155
 
 
163
  if next_bar == "":
164
  break
165
 
166
+ next_bar = remaining_tokens + next_bar
167
  remaining_tokens = ""
168
 
169
  predicted_patch = torch.tensor(
170
+ patchilizer.bar2patch(next_bar), device=device
 
171
  ).unsqueeze(0)
172
 
173
  input_patches = torch.cat(
174
+ [input_patches, predicted_patch.unsqueeze(0)], dim=1
 
175
  )
176
 
177
  else:
 
181
  print("\n")
182
 
183
  print("Generation time: {:.2f} seconds".format(time.time() - start_time))
184
+ create_dir("./tmp")
185
  timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
186
+ out_midi = abc_to_midi(tunes, f"./tmp/[{region}]{timestamp}.mid")
187
+ out_xml = abc_to_musicxml(tunes, f"./tmp/[{region}]{timestamp}.musicxml")
188
+ out_mxl = musicxml_to_mxl(f"./tmp/[{region}]{timestamp}.musicxml")
189
  pdf_file, jpg_file = mxl2jpg(out_mxl)
190
  wav_file = midi2wav(out_midi)
191
 
 
193
 
194
 
195
  def inference(region):
196
+ if os.path.exists("./tmp"):
197
+ shutil.rmtree("./tmp")
198
 
199
  parser = argparse.ArgumentParser()
200
  args = get_args(parser)
 
205
  with gr.Row():
206
  with gr.Column():
207
  region_opt = gr.Dropdown(
208
+ choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
209
+ value="Mondstadt",
210
+ label="Region genre",
 
 
211
  )
212
  gen_btn = gr.Button("Generate")
213
 
214
  with gr.Column():
215
+ wav_output = gr.Audio(label="Audio", type="filepath")
216
  dld_midi = gr.components.File(label="Download MIDI")
217
  pdf_score = gr.components.File(label="Download PDF score")
218
  dld_xml = gr.components.File(label="Download MusicXML")
219
  dld_mxl = gr.components.File(label="Download MXL")
220
+ abc_output = gr.Textbox(label="abc score", show_copy_button=True)
221
+ img_score = gr.Image(label="Staff", type="filepath")
222
 
223
  gen_btn.click(
224
  inference,
225
  inputs=region_opt,
226
  outputs=[
227
+ abc_output,
228
+ dld_midi,
229
+ pdf_score,
230
+ dld_xml,
231
+ dld_mxl,
232
+ img_score,
233
+ wav_output,
234
+ ],
235
  )
236
 
237
  demo.launch(share=True)
utils.py CHANGED
@@ -35,15 +35,16 @@ def create_dir(dir_path):
35
  def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
36
  import time
37
  import requests
 
38
  try:
39
  response = requests.get(url, stream=True)
40
- total_size = int(response.headers.get('content-length', 0))
41
  chunk_size = 1024
42
 
43
- with open(filename, 'wb') as file, tqdm(
44
  desc=f"Downloading weights to '{filename}'...",
45
  total=total_size,
46
- unit='B',
47
  unit_scale=True,
48
  unit_divisor=1024,
49
  ) as bar:
@@ -51,7 +52,7 @@ def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
51
  size = file.write(data)
52
  bar.update(size)
53
 
54
- except ConnectionError as e:
55
  print(f"Error: {e}")
56
  time.sleep(3)
57
  download(filename, ZH_WEIGHT_URL)
@@ -59,7 +60,7 @@ def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
59
 
60
  class Patchilizer:
61
  """
62
- A class for converting music bars to patches and vice versa.
63
  """
64
 
65
  def __init__(self):
@@ -73,7 +74,7 @@ class Patchilizer:
73
  """
74
  Split a body of music into individual bars.
75
  """
76
- bars = re.split(self.regexPattern, ''.join(body))
77
  bars = list(filter(None, bars))
78
  # remove empty strings
79
  if bars[0] in self.delimiters:
@@ -87,8 +88,7 @@ class Patchilizer:
87
  """
88
  Convert a bar into a patch of specified length.
89
  """
90
- patch = [self.bos_token_id] + \
91
- [ord(c) for c in bar] + [self.eos_token_id]
92
  patch = patch[:patch_size]
93
  patch += [self.pad_token_id] * (patch_size - len(patch))
94
  return patch
@@ -97,31 +97,46 @@ class Patchilizer:
97
  """
98
  Convert a patch into a bar.
99
  """
100
- return ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in patch if idx != self.eos_token_id)
 
 
 
 
101
 
102
- def encode(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=False):
 
 
 
 
 
 
103
  """
104
  Encode music into patches of specified length.
105
  """
106
- lines = unidecode(abc_code).split('\n')
107
  lines = list(filter(None, lines)) # remove empty lines
108
 
109
  body = ""
110
  patches = []
111
 
112
  for line in lines:
113
- if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%score')):
 
 
114
  if body:
115
  bars = self.split_bars(body)
116
  patches.extend(
117
- self.bar2patch(bar + '\n' if idx == len(bars) - 1 else bar, patch_size) for idx, bar in enumerate(bars)
 
 
 
118
  )
119
  body = ""
120
 
121
- patches.append(self.bar2patch(line + '\n', patch_size))
122
 
123
  else:
124
- body += line + '\n'
125
 
126
  if body:
127
  patches.extend(
@@ -129,10 +144,8 @@ class Patchilizer:
129
  )
130
 
131
  if add_special_patches:
132
- bos_patch = [self.bos_token_id] * \
133
- (patch_size-1) + [self.eos_token_id]
134
- eos_patch = [self.bos_token_id] + \
135
- [self.eos_token_id] * (patch_size-1)
136
  patches = [bos_patch] + patches + [eos_patch]
137
 
138
  return patches[:patch_length]
@@ -141,12 +154,12 @@ class Patchilizer:
141
  """
142
  Decode patches into music.
143
  """
144
- return ''.join(self.patch2bar(patch) for patch in patches)
145
 
146
 
147
  class PatchLevelDecoder(PreTrainedModel):
148
  """
149
- An Patch-level Decoder model for generating patch features in an auto-regressive manner.
150
  It inherits PreTrainedModel from transformers.
151
  """
152
 
@@ -171,7 +184,7 @@ class PatchLevelDecoder(PreTrainedModel):
171
 
172
  class CharLevelDecoder(PreTrainedModel):
173
  """
174
- A Char-level Decoder model for generating the characters within each bar patch sequentially.
175
  It inherits PreTrainedModel from transformers.
176
  """
177
 
@@ -182,7 +195,12 @@ class CharLevelDecoder(PreTrainedModel):
182
  self.eos_token_id = 2
183
  self.base = GPT2LMHeadModel(config)
184
 
185
- def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor, patch_sampling_batch_size: int):
 
 
 
 
 
186
  """
187
  The forward pass of the char-level decoder model.
188
  :param encoded_patches: the encoded patches
@@ -198,7 +216,10 @@ class CharLevelDecoder(PreTrainedModel):
198
  target_masks = target_masks.masked_fill_(labels == -100, 0)
199
 
200
  # select patches
201
- if patch_sampling_batch_size != 0 and patch_sampling_batch_size < target_patches.shape[0]:
 
 
 
202
  indices = list(range(len(target_patches)))
203
  random.shuffle(indices)
204
  selected_indices = sorted(indices[:patch_sampling_batch_size])
@@ -210,20 +231,16 @@ class CharLevelDecoder(PreTrainedModel):
210
 
211
  # get input embeddings
212
  inputs_embeds = torch.nn.functional.embedding(
213
- target_patches,
214
- self.base.transformer.wte.weight
215
  )
216
 
217
  # concatenate the encoded patches with the input embeddings
218
  inputs_embeds = torch.cat(
219
- (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]),
220
- dim=1
221
  )
222
 
223
  return self.base(
224
- inputs_embeds=inputs_embeds,
225
- attention_mask=target_masks,
226
- labels=labels
227
  )
228
 
229
  def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
@@ -237,10 +254,7 @@ class CharLevelDecoder(PreTrainedModel):
237
  tokens = tokens.reshape(1, -1)
238
 
239
  # Get input embeddings
240
- tokens = torch.nn.functional.embedding(
241
- tokens,
242
- self.base.transformer.wte.weight
243
- )
244
 
245
  # Concatenate the encoded patch with the input embeddings
246
  tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
@@ -249,17 +263,14 @@ class CharLevelDecoder(PreTrainedModel):
249
  outputs = self.base(inputs_embeds=tokens)
250
 
251
  # Get probabilities of next token
252
- probs = torch.nn.functional.softmax(
253
- outputs.logits.squeeze(0)[-1],
254
- dim=-1
255
- )
256
 
257
  return probs
258
 
259
 
260
  class TunesFormer(PreTrainedModel):
261
  """
262
- TunesFormer is a hierarchical music generation model based on bar patching.
263
  It includes a patch-level decoder and a character-level decoder.
264
  It inherits PreTrainedModel from transformers.
265
  """
@@ -271,18 +282,14 @@ class TunesFormer(PreTrainedModel):
271
  self.eos_token_id = 2
272
  if share_weights:
273
  max_layers = max(
274
- encoder_config.num_hidden_layers,
275
- decoder_config.num_hidden_layers
276
  )
277
 
278
- max_context_size = max(
279
- encoder_config.max_length,
280
- decoder_config.max_length
281
- )
282
 
283
  max_position_embeddings = max(
284
  encoder_config.max_position_embeddings,
285
- decoder_config.max_position_embeddings
286
  )
287
 
288
  encoder_config.num_hidden_layers = max_layers
@@ -298,17 +305,24 @@ class TunesFormer(PreTrainedModel):
298
  if share_weights:
299
  self.patch_level_decoder.base = self.char_level_decoder.base.transformer
300
 
301
- def forward(self, patches: torch.Tensor, patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE):
 
 
 
 
302
  """
303
  The forward pass of the TunesFormer model.
304
  :param patches: the patches to be both encoded and decoded
305
  :return: the decoded patches
306
  """
307
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
308
- encoded_patches = self.patch_level_decoder(
309
- patches)["last_hidden_state"]
310
 
311
- return self.char_level_decoder(encoded_patches.squeeze(0)[:-1, :], patches.squeeze(0)[1:, :], patch_sampling_batch_size)
 
 
 
 
312
 
313
  def generate(
314
  self,
@@ -317,7 +331,7 @@ class TunesFormer(PreTrainedModel):
317
  top_p: float = 1,
318
  top_k: int = 0,
319
  temperature: float = 1,
320
- seed: int = None
321
  ):
322
  """
323
  The generate function for generating patches based on patches.
@@ -325,8 +339,7 @@ class TunesFormer(PreTrainedModel):
325
  :return: the generated patches
326
  """
327
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
328
- encoded_patches = self.patch_level_decoder(
329
- patches)["last_hidden_state"]
330
 
331
  if tokens == None:
332
  tokens = torch.tensor([self.bos_token_id], device=self.device)
@@ -342,19 +355,17 @@ class TunesFormer(PreTrainedModel):
342
  else:
343
  n_seed = None
344
 
345
- prob = self.char_level_decoder.generate(
346
- encoded_patches[0][-1],
347
- tokens
348
- ).cpu().detach().numpy()
 
 
349
 
350
  prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
351
  prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
352
 
353
- token = temperature_sampling(
354
- prob,
355
- temperature=temperature,
356
- seed=n_seed
357
- )
358
 
359
  generated_patch.append(token)
360
  if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
@@ -362,8 +373,7 @@ class TunesFormer(PreTrainedModel):
362
 
363
  else:
364
  tokens = torch.cat(
365
- (tokens, torch.tensor([token], device=self.device)),
366
- dim=0
367
  )
368
 
369
  return generated_patch, n_seed
@@ -374,8 +384,9 @@ class PatchilizedData(Dataset):
374
  self.texts = []
375
 
376
  for item in tqdm(items):
377
- text = item['control code'] + \
378
- "\n".join(item['abc notation'].split('\n')[1:])
 
379
  input_patch = patchilizer.encode(text, add_special_patches=True)
380
  input_patch = torch.tensor(input_patch)
381
  if torch.sum(input_patch) != 0:
 
35
  def download(filename=WEIGHT_PATH, url=WEIGHT_URL):
36
  import time
37
  import requests
38
+
39
  try:
40
  response = requests.get(url, stream=True)
41
+ total_size = int(response.headers.get("content-length", 0))
42
  chunk_size = 1024
43
 
44
+ with open(filename, "wb") as file, tqdm(
45
  desc=f"Downloading weights to '{filename}'...",
46
  total=total_size,
47
+ unit="B",
48
  unit_scale=True,
49
  unit_divisor=1024,
50
  ) as bar:
 
52
  size = file.write(data)
53
  bar.update(size)
54
 
55
+ except Exception as e:
56
  print(f"Error: {e}")
57
  time.sleep(3)
58
  download(filename, ZH_WEIGHT_URL)
 
60
 
61
  class Patchilizer:
62
  """
63
+ A class for converting music bars to patches and vice versa.
64
  """
65
 
66
  def __init__(self):
 
74
  """
75
  Split a body of music into individual bars.
76
  """
77
+ bars = re.split(self.regexPattern, "".join(body))
78
  bars = list(filter(None, bars))
79
  # remove empty strings
80
  if bars[0] in self.delimiters:
 
88
  """
89
  Convert a bar into a patch of specified length.
90
  """
91
+ patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
 
92
  patch = patch[:patch_size]
93
  patch += [self.pad_token_id] * (patch_size - len(patch))
94
  return patch
 
97
  """
98
  Convert a patch into a bar.
99
  """
100
+ return "".join(
101
+ chr(idx) if idx > self.eos_token_id else ""
102
+ for idx in patch
103
+ if idx != self.eos_token_id
104
+ )
105
 
106
+ def encode(
107
+ self,
108
+ abc_code,
109
+ patch_length=PATCH_LENGTH,
110
+ patch_size=PATCH_SIZE,
111
+ add_special_patches=False,
112
+ ):
113
  """
114
  Encode music into patches of specified length.
115
  """
116
+ lines = unidecode(abc_code).split("\n")
117
  lines = list(filter(None, lines)) # remove empty lines
118
 
119
  body = ""
120
  patches = []
121
 
122
  for line in lines:
123
+ if len(line) > 1 and (
124
+ (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
125
+ ):
126
  if body:
127
  bars = self.split_bars(body)
128
  patches.extend(
129
+ self.bar2patch(
130
+ bar + "\n" if idx == len(bars) - 1 else bar, patch_size
131
+ )
132
+ for idx, bar in enumerate(bars)
133
  )
134
  body = ""
135
 
136
+ patches.append(self.bar2patch(line + "\n", patch_size))
137
 
138
  else:
139
+ body += line + "\n"
140
 
141
  if body:
142
  patches.extend(
 
144
  )
145
 
146
  if add_special_patches:
147
+ bos_patch = [self.bos_token_id] * (patch_size - 1) + [self.eos_token_id]
148
+ eos_patch = [self.bos_token_id] + [self.eos_token_id] * (patch_size - 1)
 
 
149
  patches = [bos_patch] + patches + [eos_patch]
150
 
151
  return patches[:patch_length]
 
154
  """
155
  Decode patches into music.
156
  """
157
+ return "".join(self.patch2bar(patch) for patch in patches)
158
 
159
 
160
  class PatchLevelDecoder(PreTrainedModel):
161
  """
162
+ An Patch-level Decoder model for generating patch features in an auto-regressive manner.
163
  It inherits PreTrainedModel from transformers.
164
  """
165
 
 
184
 
185
  class CharLevelDecoder(PreTrainedModel):
186
  """
187
+ A Char-level Decoder model for generating the characters within each bar patch sequentially.
188
  It inherits PreTrainedModel from transformers.
189
  """
190
 
 
195
  self.eos_token_id = 2
196
  self.base = GPT2LMHeadModel(config)
197
 
198
+ def forward(
199
+ self,
200
+ encoded_patches: torch.Tensor,
201
+ target_patches: torch.Tensor,
202
+ patch_sampling_batch_size: int,
203
+ ):
204
  """
205
  The forward pass of the char-level decoder model.
206
  :param encoded_patches: the encoded patches
 
216
  target_masks = target_masks.masked_fill_(labels == -100, 0)
217
 
218
  # select patches
219
+ if (
220
+ patch_sampling_batch_size != 0
221
+ and patch_sampling_batch_size < target_patches.shape[0]
222
+ ):
223
  indices = list(range(len(target_patches)))
224
  random.shuffle(indices)
225
  selected_indices = sorted(indices[:patch_sampling_batch_size])
 
231
 
232
  # get input embeddings
233
  inputs_embeds = torch.nn.functional.embedding(
234
+ target_patches, self.base.transformer.wte.weight
 
235
  )
236
 
237
  # concatenate the encoded patches with the input embeddings
238
  inputs_embeds = torch.cat(
239
+ (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
 
240
  )
241
 
242
  return self.base(
243
+ inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
 
 
244
  )
245
 
246
  def generate(self, encoded_patch: torch.Tensor, tokens: torch.Tensor):
 
254
  tokens = tokens.reshape(1, -1)
255
 
256
  # Get input embeddings
257
+ tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
 
 
 
258
 
259
  # Concatenate the encoded patch with the input embeddings
260
  tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
 
263
  outputs = self.base(inputs_embeds=tokens)
264
 
265
  # Get probabilities of next token
266
+ probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
 
 
 
267
 
268
  return probs
269
 
270
 
271
  class TunesFormer(PreTrainedModel):
272
  """
273
+ TunesFormer is a hierarchical music generation model based on bar patching.
274
  It includes a patch-level decoder and a character-level decoder.
275
  It inherits PreTrainedModel from transformers.
276
  """
 
282
  self.eos_token_id = 2
283
  if share_weights:
284
  max_layers = max(
285
+ encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
 
286
  )
287
 
288
+ max_context_size = max(encoder_config.max_length, decoder_config.max_length)
 
 
 
289
 
290
  max_position_embeddings = max(
291
  encoder_config.max_position_embeddings,
292
+ decoder_config.max_position_embeddings,
293
  )
294
 
295
  encoder_config.num_hidden_layers = max_layers
 
305
  if share_weights:
306
  self.patch_level_decoder.base = self.char_level_decoder.base.transformer
307
 
308
+ def forward(
309
+ self,
310
+ patches: torch.Tensor,
311
+ patch_sampling_batch_size: int = PATCH_SAMPLING_BATCH_SIZE,
312
+ ):
313
  """
314
  The forward pass of the TunesFormer model.
315
  :param patches: the patches to be both encoded and decoded
316
  :return: the decoded patches
317
  """
318
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
319
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
 
320
 
321
+ return self.char_level_decoder(
322
+ encoded_patches.squeeze(0)[:-1, :],
323
+ patches.squeeze(0)[1:, :],
324
+ patch_sampling_batch_size,
325
+ )
326
 
327
  def generate(
328
  self,
 
331
  top_p: float = 1,
332
  top_k: int = 0,
333
  temperature: float = 1,
334
+ seed: int = None,
335
  ):
336
  """
337
  The generate function for generating patches based on patches.
 
339
  :return: the generated patches
340
  """
341
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
342
+ encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
 
343
 
344
  if tokens == None:
345
  tokens = torch.tensor([self.bos_token_id], device=self.device)
 
355
  else:
356
  n_seed = None
357
 
358
+ prob = (
359
+ self.char_level_decoder.generate(encoded_patches[0][-1], tokens)
360
+ .cpu()
361
+ .detach()
362
+ .numpy()
363
+ )
364
 
365
  prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
366
  prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
367
 
368
+ token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
 
 
 
 
369
 
370
  generated_patch.append(token)
371
  if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
 
373
 
374
  else:
375
  tokens = torch.cat(
376
+ (tokens, torch.tensor([token], device=self.device)), dim=0
 
377
  )
378
 
379
  return generated_patch, n_seed
 
384
  self.texts = []
385
 
386
  for item in tqdm(items):
387
+ text = item["control code"] + "\n".join(
388
+ item["abc notation"].split("\n")[1:]
389
+ )
390
  input_patch = patchilizer.encode(text, add_special_patches=True)
391
  input_patch = torch.tensor(input_patch)
392
  if torch.sum(input_patch) != 0: