sunsetsobserver
committed on
Commit
•
7bfacad
1
Parent(s):
ae3131d
Prep for ALICE training
Browse files- Maestro/.DS_Store +0 -0
- train.py +1 -44
Maestro/.DS_Store
CHANGED
Binary files a/Maestro/.DS_Store and b/Maestro/.DS_Store differ
|
|
train.py
CHANGED
@@ -182,47 +182,4 @@ trainer.save_model() # Saves the tokenizer too
|
|
182 |
trainer.log_metrics("train", train_result.metrics)
|
183 |
trainer.save_metrics("train", train_result.metrics)
|
184 |
trainer.save_state()
|
185 |
-
|
186 |
-
|
187 |
-
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
|
188 |
-
generation_config = GenerationConfig(
|
189 |
-
max_new_tokens=512, # extends samples by 512 tokens
|
190 |
-
num_beams=1, # no beam search
|
191 |
-
do_sample=True, # but sample instead
|
192 |
-
temperature=0.9,
|
193 |
-
top_k=15,
|
194 |
-
top_p=0.95,
|
195 |
-
epsilon_cutoff=3e-4,
|
196 |
-
eta_cutoff=1e-3,
|
197 |
-
pad_token_id=config.padding_token_id,
|
198 |
-
)
|
199 |
-
|
200 |
-
# Here the sequences are padded to the left, so that the last token along the time dimension
|
201 |
-
# is always the last token of each seq, allowing to efficiently generate by batch
|
202 |
-
collator.pad_on_left = True
|
203 |
-
collator.eos_token = None
|
204 |
-
dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator)
|
205 |
-
model.eval()
|
206 |
-
count = 0
|
207 |
-
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'): # (N,T)
|
208 |
-
res = model.generate(
|
209 |
-
inputs=batch["input_ids"].to(model.device),
|
210 |
-
attention_mask=batch["attention_mask"].to(model.device),
|
211 |
-
generation_config=generation_config) # (N,T)
|
212 |
-
|
213 |
-
# Saves the generated music, as MIDI files and tokens (json)
|
214 |
-
for prompt, continuation in zip(batch["input_ids"], res):
|
215 |
-
generated = continuation[len(prompt):]
|
216 |
-
midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
|
217 |
-
tokens = [generated, prompt, continuation] # list compr. as seqs of dif. lengths
|
218 |
-
tokens = [seq.tolist() for seq in tokens]
|
219 |
-
for tok_seq in tokens[1:]:
|
220 |
-
_midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
|
221 |
-
midi.instruments.append(_midi.instruments[0])
|
222 |
-
midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
|
223 |
-
midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
|
224 |
-
midi.instruments[2].name = f'Original sample and continuation'
|
225 |
-
midi.dump_midi(gen_results_path / f'{count}.mid')
|
226 |
-
tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')
|
227 |
-
|
228 |
-
count += 1
|
|
|
182 |
trainer.log_metrics("train", train_result.metrics)
|
183 |
trainer.save_metrics("train", train_result.metrics)
|
184 |
trainer.save_state()
|
185 |
+
trainer.push_to_hub()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|