# Uploaded by sunsetsobserver, commit 2171a21 ("working generator"), 4.1 kB.
from copy import deepcopy
from pathlib import Path
from random import shuffle
from torch import Tensor, argmax
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
from transformers.trainer_utils import set_seed
from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm
# Tokenizer hyper-parameters (REMI).
# NOTE(review): `config` built at the end of this section is not passed to the tokenizer,
# which is loaded with REMI.from_pretrained(...). These values therefore only record the
# intended setup. Confirm they match the pretrained tokenizer's configuration.
PITCH_RANGE = (21, 109)  # pitch bounds handed to miditok as `pitch_range`
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}  # beat range -> resolution
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]

# Optional token families
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False

# Tempo quantisation
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)

TOKENIZER_PARAMS = dict(
    pitch_range=PITCH_RANGE,
    beat_res=BEAT_RES,
    num_velocities=NUM_VELOCITIES,
    special_tokens=SPECIAL_TOKENS,
    use_chords=USE_CHORDS,
    use_rests=USE_RESTS,
    use_tempos=USE_TEMPOS,
    use_time_signatures=USE_TIME_SIGNATURE,
    use_programs=USE_PROGRAMS,
    num_tempos=NUM_TEMPOS,
    tempo_range=TEMPO_RANGE,
)
config = TokenizerConfig(**TOKENIZER_PARAMS)
# Seed the random number generators so that sampling during generation is reproducible.
set_seed(777)

# Load the pretrained REMI tokenizer that matches the model checkpoint.
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")

# Collect the prompt MIDI files: every .mid / .midi file under ./input.
# - The suffix is matched case-insensitively, so files such as SONG.MID are not missed on
#   case-sensitive file systems.
# - The list is sorted, so the prompt order (and therefore the numbering of the generated
#   files) does not depend on the file system's directory-listing order.
# To prompt from the Maestro dataset instead, use Path('Maestro') as the root.
MIDI_SUFFIXES = {".mid", ".midi"}
midi_paths = sorted(
    path
    for path in Path('input').rglob('*')
    if path.is_file() and path.suffix.lower() in MIDI_SUFFIXES
)

# Tokenize the prompts. Sequences shorter than 10 or longer than 1024 tokens are
# handled by DatasetTok according to these bounds.
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)
# The generation collator is created together with the model setup. The one
# previously built here was overwritten before ever being used.
# Load the trained checkpoint from ./runs and place it on the fastest available device.
# Without this step, generation always ran on the CPU, even though the CUDA / MPS
# availability checks are imported. The generation loop sends its inputs to
# `model.device`, so it follows whichever device is selected here.
if cuda_available():
    device = "cuda"
elif mps_available():
    device = "mps"
else:
    device = "cpu"
model = MistralForCausalLM.from_pretrained("./runs").to(device)

# Collator that batches the prompt sequences (PAD / BOS / EOS ids from the tokenizer).
# copy_inputs_as_labels presumably adds a "labels" entry. The generation loop never
# reads it.
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)

# Output directory for the generated MIDI and token (JSON) files.
gen_results_path = Path('gen_res')
gen_results_path.mkdir(parents=True, exist_ok=True)

generation_config = GenerationConfig(
    max_new_tokens=512,  # extends each sample by at most 512 tokens
    num_beams=1,  # no beam search
    do_sample=True,  # but sample instead
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
    # Pad finished rows with the same PAD id the collator uses for the prompts.
    pad_token_id=tokenizer["PAD_None"],
)

# Pad prompts on the left. The last position of every row is then that row's last real
# token, which lets generate() continue a whole batch at once.
collator.pad_on_left = True
# Presumably stops the collator from appending EOS to prompts, so the model continues
# them instead of seeing them as finished. TODO: confirm against miditok's DataCollator.
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
# Switch the model to inference mode (e.g. disables dropout) before generating.
model.eval()
count = 0  # running sample index, used to name the output files
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config,
    )  # (N,T'): each row is its prompt followed by the newly generated tokens

    # Saves the generated music, as MIDI files and tokens (json)
    for prompt, continuation in zip(batch["input_ids"], res):
        generated = continuation[len(prompt):]  # only the newly generated tokens
        # Plain Python lists, because the three sequences have different lengths.
        tokens = [seq.tolist() for seq in (generated, prompt, continuation)]
        # Decode each sequence separately. deepcopy protects `tokens`, which is saved
        # below, from any in-place modification made while decoding.
        scores = [tokenizer.tokens_to_midi([deepcopy(seq)]) for seq in tokens]
        # A sequence without notes may decode to a score with no track. Indexing
        # tracks[0] would then raise IndexError and abort the whole run, so skip the
        # sample instead. Its index is still consumed, so file names keep matching
        # the sample order.
        if any(len(score.tracks) == 0 for score in scores):
            tqdm.write(f'Skipping sample {count}: a token sequence decoded to no track')
            count += 1
            continue
        # Merge into one file: track 0 = continuation, 1 = prompt, 2 = prompt + continuation.
        midi = scores[0]
        for score in scores[1:]:
            midi.tracks.append(score.tracks[0])
        midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.tracks[2].name = 'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')
        count += 1