sunsetsobserver committed on
Commit
7bfacad
1 Parent(s): ae3131d

Prep for ALICE training

Browse files
Files changed (2) hide show
  1. Maestro/.DS_Store +0 -0
  2. train.py +1 -44
Maestro/.DS_Store CHANGED
Binary files a/Maestro/.DS_Store and b/Maestro/.DS_Store differ
 
train.py CHANGED
@@ -182,47 +182,4 @@ trainer.save_model() # Saves the tokenizer too
182
  trainer.log_metrics("train", train_result.metrics)
183
  trainer.save_metrics("train", train_result.metrics)
184
  trainer.save_state()
185
-
186
-
187
- (gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
188
- generation_config = GenerationConfig(
189
- max_new_tokens=512, # extends samples by 512 tokens
190
- num_beams=1, # no beam search
191
- do_sample=True, # but sample instead
192
- temperature=0.9,
193
- top_k=15,
194
- top_p=0.95,
195
- epsilon_cutoff=3e-4,
196
- eta_cutoff=1e-3,
197
- pad_token_id=config.padding_token_id,
198
- )
199
-
200
- # Here the sequences are padded to the left, so that the last token along the time dimension
201
- # is always the last token of each seq, allowing to efficiently generate by batch
202
- collator.pad_on_left = True
203
- collator.eos_token = None
204
- dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator)
205
- model.eval()
206
- count = 0
207
- for batch in tqdm(dataloader_test, desc='Testing model / Generating results'): # (N,T)
208
- res = model.generate(
209
- inputs=batch["input_ids"].to(model.device),
210
- attention_mask=batch["attention_mask"].to(model.device),
211
- generation_config=generation_config) # (N,T)
212
-
213
- # Saves the generated music, as MIDI files and tokens (json)
214
- for prompt, continuation in zip(batch["input_ids"], res):
215
- generated = continuation[len(prompt):]
216
- midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
217
- tokens = [generated, prompt, continuation] # list compr. as seqs of dif. lengths
218
- tokens = [seq.tolist() for seq in tokens]
219
- for tok_seq in tokens[1:]:
220
- _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
221
- midi.instruments.append(_midi.instruments[0])
222
- midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
223
- midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
224
- midi.instruments[2].name = f'Original sample and continuation'
225
- midi.dump_midi(gen_results_path / f'{count}.mid')
226
- tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')
227
-
228
- count += 1
 
182
  trainer.log_metrics("train", train_result.metrics)
183
  trainer.save_metrics("train", train_result.metrics)
184
  trainer.save_state()
185
+ trainer.push_to_hub()