|
""" Code by Nathan Fradet https://github.com/Natooz, reworked by Adam Łukawski https://github.com/sunsetsobserver """ |
|
|
|
from copy import deepcopy |
|
from pathlib import Path |
|
from random import shuffle |
|
|
|
from torch import Tensor, argmax |
|
from torch.utils.data import DataLoader |
|
from torch.cuda import is_available as cuda_available, is_bf16_supported |
|
from torch.backends.mps import is_available as mps_available |
|
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, AutoModel |
|
from transformers.trainer_utils import set_seed |
|
from evaluate import load as load_metric |
|
from miditok import REMI, TokenizerConfig |
|
from miditok.pytorch_data import DatasetTok, DataCollator |
|
from tqdm import tqdm |
|
|
|
|
|
# Seed RNGs via transformers' set_seed so shuffling and sampling are reproducible.
set_seed(777)
|
|
|
|
|
# Load the pretrained REMI tokenizer from the Hugging Face Hub repo "sunsetsobserver/MIDI".
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")
|
|
|
|
|
# Collect every MIDI file under the Maestro folder (.mid first, then .midi,
# matching the original glob concatenation order).
midi_paths = [
    p
    for pattern in ("**/*.mid", "**/*.midi")
    for p in Path("Maestro").glob(pattern)
]

# Shuffle once, then carve out the evaluation slice: the first 20% of the
# shuffled list is (implicitly) reserved for validation, and the next 10%
# becomes the test set used below.
total_num_files = len(midi_paths)
num_files_valid = round(total_num_files * 0.2)
num_files_test = round(total_num_files * 0.1)
shuffle(midi_paths)
midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]
|
|
|
|
|
# Build the test dataset: miditok's DatasetTok wraps the test MIDI files,
# keeping token sequences between 256 and 1024 tokens (presumably by
# chunking/filtering — see miditok docs to confirm).
kwargs_dataset = {"min_seq_len": 256, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths_test, **kwargs_dataset)
# NOTE: the DataCollator previously built here was dead code — it was
# overwritten (with copy_inputs_as_labels=True) further down before any use,
# so it has been removed.
|
|
|
|
|
# Load the pretrained causal-LM checkpoint from the Hub.
model = AutoModelForCausalLM.from_pretrained("sunsetsobserver/MIDI/runs")

# Collator that pads with PAD_None, frames sequences with BOS/EOS, and
# copies the inputs as labels (standard causal-LM setup).
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)
|
|
|
# Directory where generated MIDI files and token dumps are written.
gen_results_path = Path("gen_res")
gen_results_path.mkdir(parents=True, exist_ok=True)

# Sampling configuration for autoregressive generation: single-beam
# stochastic sampling with temperature plus top-k / top-p / epsilon / eta
# truncation.
generation_config = GenerationConfig(
    max_new_tokens=512,   # extends each prompt by up to 512 tokens
    num_beams=1,          # plain sampling, no beam search
    do_sample=True,
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
    # NOTE(review): recent miditok versions expose this as `pad_token_id`,
    # not `padding_token_id` — confirm against the installed miditok version.
    pad_token_id=tokenizer.padding_token_id,
)
|
|
|
|
|
|
|
# Reconfigure the shared collator for generation: pad on the left so the
# prompt ends right before the tokens to be generated, and disable EOS so
# prompts are not terminated early.
collator.pad_on_left = True
collator.eos_token = None
# Batched iteration over the test set using the generation-ready collator.
dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator)
# Inference mode; `count` numbers the output files written below.
model.eval()
count = 0
|
# For each batch: generate continuations for every prompt, then write one
# MIDI file per sample containing three tracks (continuation alone, the
# original prompt, and prompt+continuation) plus a JSON dump of the tokens.
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)

    # `res` contains prompt + generated tokens; slice off the prompt prefix
    # to isolate the newly generated part.
    for prompt, continuation in zip(batch["input_ids"], res):
        generated = continuation[len(prompt):]
        # Decode the generated tokens into a MIDI object (first instrument).
        midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
        tokens = [generated, prompt, continuation]
        tokens = [seq.tolist() for seq in tokens]
        # Decode prompt and full sequence too, appending each as an extra
        # instrument track of the same MIDI file for side-by-side listening.
        for tok_seq in tokens[1:]:
            _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
            midi.instruments.append(_midi.instruments[0])
        midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.instruments[2].name = f'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        # Incremented per sample (inside the inner loop) so every sample in
        # the batch gets a unique file name.
        count += 1