# MIDI_Transformer_Mistral / generate_on_one_track.py
# sunsetsobserver's picture
# Add generate only continuation
# ae3131d
from copy import deepcopy
from pathlib import Path
from random import shuffle
from torch import Tensor, argmax
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
from transformers.trainer_utils import set_seed
from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm
# Our tokenizer's configuration.
# Every setting lives in its own module-level constant; TOKENIZER_PARAMS then
# gathers them under the keyword names later unpacked into TokenizerConfig.

# Value ranges and resolutions
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)

# Special tokens of the vocabulary
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]

# Optional token types: only tempo tokens are switched on
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False

# Keyword arguments for the tokenizer configuration (key order is preserved)
TOKENIZER_PARAMS = dict(
    pitch_range=PITCH_RANGE,
    beat_res=BEAT_RES,
    num_velocities=NUM_VELOCITIES,
    special_tokens=SPECIAL_TOKENS,
    use_chords=USE_CHORDS,
    use_rests=USE_RESTS,
    use_tempos=USE_TEMPOS,
    use_time_signatures=USE_TIME_SIGNATURE,
    use_programs=USE_PROGRAMS,
    num_tempos=NUM_TEMPOS,
    tempo_range=TEMPO_RANGE,
)
# Tokenizer configuration built from TOKENIZER_PARAMS.
# NOTE: it is not handed to the tokenizer below, which is loaded with
# from_pretrained; the name is kept so `config` stays available at module level.
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Seed
set_seed(777)

# Loads the pretrained tokenizer
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")

# Collects every MIDI file found under ./input (both file extensions).
# Sorted because glob order depends on the filesystem: a stable order keeps
# the numbering of the generated files reproducible from one run to the next.
midi_paths = sorted(
    list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))
)
# To also use the Maestro dataset, extend the list above with:
#   list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi'))

# Tokenizes the MIDI files into the dataset used for generation
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)

# Loads the trained model weights from the ./runs directory
model = MistralForCausalLM.from_pretrained("./runs")

# Single data collator. The earlier duplicate, built without
# copy_inputs_as_labels and overwritten right away, has been removed.
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)
# Directory receiving the generated files; created on demand
gen_results_path = Path('gen_res')
gen_results_path.mkdir(parents=True, exist_ok=True)

# Decoding options: sampling only (a single beam), each sample being
# extended by at most 512 new tokens.
_generation_kwargs = {
    "max_new_tokens": 512,  # extends samples by 512 tokens
    "num_beams": 1,  # no beam search
    "do_sample": True,  # but sample instead
    "temperature": 0.9,
    "top_k": 15,
    "top_p": 0.95,
    "epsilon_cutoff": 3e-4,
    "eta_cutoff": 1e-3,
}
generation_config = GenerationConfig(**_generation_kwargs)
# Here the sequences are padded to the left, so that the last token along the time dimension
# is always the last token of each seq, allowing to efficiently generate by batch
collator.pad_on_left = True
# No EOS token on the prompts (presumably so the model continues them rather
# than seeing an already finished sequence -- confirm against DataCollator).
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)

model.eval()
count = 0  # running index used to name the output files
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    # The output holds the prompt followed by the newly generated tokens,
    # i.e. it is longer than the input along the time dimension.
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config,
    )

    # Saves the generated music, as MIDI files and tokens (json)
    for generated in res:
        # Full sequence (prompt + continuation) as plain Python ints,
        # converted once and reused for both outputs.
        token_ids = generated.tolist()

        # Decode a copy so the list saved as json below cannot be altered
        # by the tokenizer.
        midi = tokenizer.tokens_to_midi([list(token_ids)])

        # Name the track to indicate it includes both the original and the
        # continuation. Guarded: indexing tracks[0] would raise IndexError
        # if decoding produced no track at all.
        if midi.tracks:
            midi.tracks[0].name = f'Original sample and continuation ({len(token_ids)} tokens)'

        # Dump the MIDI file and the tokens of the combined sequence
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens([token_ids], gen_results_path / f'{count}.json')
        count += 1