""" Code by Nathan Fradet https://github.com/Natooz, reworked by Adam Ɓukawski https://github.com/sunsetsobserver """ from copy import deepcopy from pathlib import Path from random import shuffle from torch import Tensor, argmax from torch.utils.data import DataLoader from torch.cuda import is_available as cuda_available, is_bf16_supported from torch.backends.mps import is_available as mps_available from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, AutoModel from transformers.trainer_utils import set_seed from evaluate import load as load_metric from miditok import REMI, TokenizerConfig from miditok.pytorch_data import DatasetTok, DataCollator from tqdm import tqdm # Seed set_seed(777) # Creates the tokenizer tokenizer = AutoTokenizer.from_pretrained("tokenizer") # Trains the tokenizer with Byte Pair Encoding (BPE) to build the vocabulary, here 10k tokens midi_paths = list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi')) # Split MIDI paths in train/valid/test sets total_num_files = len(midi_paths) num_files_valid = round(total_num_files * 0.2) num_files_test = round(total_num_files * 0.1) shuffle(midi_paths) midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test] # Loads tokens and create data collator kwargs_dataset = {"min_seq_len": 256, "max_seq_len": 1024, "tokenizer": tokenizer} dataset_test = DatasetTok(midi_paths_test, **kwargs_dataset) collator = DataCollator( tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"] ) # Creates model using the correct configuration model = AutoModel.from_pretrained("./runs/model.safetensors") collator = DataCollator(tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"], copy_inputs_as_labels=True) (gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True) generation_config = GenerationConfig( max_new_tokens=512, # extends samples by 512 tokens num_beams=1, # no beam search do_sample=True, # but sample instead temperature=0.9, top_k=15, top_p=0.95, epsilon_cutoff=3e-4, eta_cutoff=1e-3, pad_token_id=tokenizer.padding_token_id, ) # Here the sequences are padded to the left, so that the last token along the time dimension # is always the last token of each seq, allowing to efficiently generate by batch collator.pad_on_left = True collator.eos_token = None dataloader_test = DataLoader(dataset_test, batch_size=16, collate_fn=collator) model.eval() count = 0 for batch in tqdm(dataloader_test, desc='Testing model / Generating results'): # (N,T) res = model.generate( inputs=batch["input_ids"].to(model.device), attention_mask=batch["attention_mask"].to(model.device), generation_config=generation_config) # (N,T) # Saves the generated music, as MIDI files and tokens (json) for prompt, continuation in zip(batch["input_ids"], res): generated = continuation[len(prompt):] midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())]) tokens = [generated, prompt, continuation] # list compr. as seqs of dif. 
lengths tokens = [seq.tolist() for seq in tokens] for tok_seq in tokens[1:]: _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)]) midi.instruments.append(_midi.instruments[0]) midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)' midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)' midi.instruments[2].name = f'Original sample and continuation' midi.dump_midi(gen_results_path / f'{count}.mid') tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json') count += 1
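
# Optional closing summary: reports how many continuations were written and
# where; it only uses names defined above, so it is safe to drop if unwanted.
print(f"Saved {count} continuations (MIDI + token JSON) to {gen_results_path.resolve()}")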