from copy import deepcopy
from pathlib import Path
from random import shuffle

from torch import Tensor, argmax
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
from transformers.trainer_utils import set_seed
from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm

# Tokenizer configuration (REMI)
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
# Kept for reference; the pretrained tokenizer loaded below ships with its own saved configuration
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Seed for reproducible sampling
set_seed(777)

# Pretrained REMI tokenizer from the Hugging Face Hub
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")

# MIDI files used as prompts
midi_paths = list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))

""" list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi')) """ |
|
|
|
|
|
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer} |
|
dataset_test = DatasetTok(midi_paths, **kwargs_dataset) |
|
# Trained model checkpoint from the training run directory
model = MistralForCausalLM.from_pretrained("./runs")

# Collator used at generation time; inputs are copied as labels
collator = DataCollator(tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"], copy_inputs_as_labels=True)

(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)

# Sampling parameters for autoregressive generation
generation_config = GenerationConfig(
    max_new_tokens=512,  # each prompt is extended by up to 512 tokens
    num_beams=1,
    do_sample=True,
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
)
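
# Pad on the left so the last position of every sequence is a real token,
# which lets the model generate continuations for a whole batch at once.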
collator.pad_on_left = True
collator.eos_token = None  # do not append EOS to the prompts
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
model.eval()
count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)  # (N, T): each prompt followed by its continuation
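
    # Save the generated music as MIDI files and the token sequences as JSON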
    for prompt, continuation in zip(batch["input_ids"], res):
        generated = continuation[len(prompt):]
        midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
        tokens = [generated, prompt, continuation]
        tokens = [seq.tolist() for seq in tokens]
        # Add the original prompt and the full sequence as extra tracks for comparison
        for tok_seq in tokens[1:]:
            _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
            midi.tracks.append(_midi.tracks[0])
        midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.tracks[2].name = 'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        count += 1
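
# A minimal sketch of how one of the saved results could be reloaded and turned
# back into a MIDI file for inspection. It assumes the JSON written by
# save_tokens() stores the token sequences under an "ids" key, and that at
# least one prompt was processed (so gen_res/0.json exists).
import json

with open(gen_results_path / '0.json') as f:
    saved = json.load(f)
rebuilt = tokenizer.tokens_to_midi([deepcopy(saved["ids"][0])])  # first sequence = generated continuation
rebuilt.dump_midi(gen_results_path / '0_rebuilt.mid')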