from copy import deepcopy
from pathlib import Path
from random import shuffle

from torch import Tensor, argmax
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import (
    AutoModelForCausalLM,
    MistralConfig,
    Trainer,
    TrainingArguments,
    GenerationConfig,
    AutoTokenizer,
    MistralForCausalLM,
)
from transformers.trainer_utils import set_seed
from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm

# Generate continuations of the MIDI files under ./input with the model saved
# in ./runs, writing each result to ./gen_res as a 3-track MIDI file plus the
# corresponding token sequences as JSON.

# Our tokenizer's configuration
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
# NOTE: kept for reference only -- the tokenizer below is loaded from the Hub
# and does not use this config.
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Seed
set_seed(777)

# Loads the pretrained tokenizer
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")
midi_paths = list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))
# Alternative data source:
# list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi'))

# Loads tokens (sequences of 10 to 1024 tokens) used as generation prompts
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)

# Loads the trained model from ./runs. It is moved to the GPU when CUDA is
# available, otherwise it stays on the CPU. Inputs are moved to model.device below.
model = MistralForCausalLM.from_pretrained("./runs")
if cuda_available():
    model = model.to("cuda")

collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)

(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
generation_config = GenerationConfig(
    max_new_tokens=512,  # extends samples by 512 tokens
    num_beams=1,         # no beam search
    do_sample=True,      # but sample instead
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
)

# Here the sequences are padded to the left, so that the last token along the time dimension
# is always the last token of each seq, allowing to efficiently generate by batch
collator.pad_on_left = True
# No EOS is appended to the prompts, so the model continues them instead of
# seeing them as finished.
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
model.eval()
count = 0  # running index of saved samples, used for output file names
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config,
    )  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json)
    for prompt, continuation in zip(batch["input_ids"], res):
        # The slice assumes generate() echoes the prompt before the new tokens
        # (standard for decoder-only HF models), so only the new part is kept.
        generated = continuation[len(prompt):]
        midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
        tokens = [generated, prompt, continuation]  # list compr. as seqs of dif. lengths
        tokens = [seq.tolist() for seq in tokens]
        # Track 0 = continuation only; track 1 = prompt; track 2 = prompt + continuation
        for tok_seq in tokens[1:]:
            _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
            midi.tracks.append(_midi.tracks[0])
        midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.tracks[2].name = 'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        count += 1