File size: 4,022 Bytes
ae3131d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from copy import deepcopy
from pathlib import Path
from random import shuffle

from torch import Tensor, argmax
from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
from transformers.trainer_utils import set_seed
from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm

# Tokenizer hyper-parameters.
# Each value is kept as a named constant and then gathered, under the keyword
# name TokenizerConfig expects, into TOKENIZER_PARAMS.
# NOTE(review): the tokenizer used later in this script is loaded with
# REMI.from_pretrained(...), so `config` does not appear to be consumed
# here -- confirm before relying on these values.

# Note / timing resolution
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24

# Vocabulary extras
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]

# Optional token families (only tempo tokens are switched on)
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False

# Tempo settings
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)

TOKENIZER_PARAMS = dict(
    pitch_range=PITCH_RANGE,
    beat_res=BEAT_RES,
    num_velocities=NUM_VELOCITIES,
    special_tokens=SPECIAL_TOKENS,
    use_chords=USE_CHORDS,
    use_rests=USE_RESTS,
    use_tempos=USE_TEMPOS,
    use_time_signatures=USE_TIME_SIGNATURE,
    use_programs=USE_PROGRAMS,
    num_tempos=NUM_TEMPOS,
    tempo_range=TEMPO_RANGE,
)
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Fix the random seed before anything stochastic runs
set_seed(777)

# Loads the already-trained tokenizer from "sunsetsobserver/MIDI"
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")

# Collects every prompt file below ./input, accepting both MIDI extensions.
# The result is sorted because glob() yields entries in an unspecified,
# filesystem-dependent order; without sorting, the numbered output files
# ({count}.mid / {count}.json) would not map to the same prompts on every run.
midi_paths = sorted(
    path
    for pattern in ('**/*.mid', '**/*.midi')
    for path in Path('input').glob(pattern)
)

# To generate from a 'Maestro' directory instead, glob Path('Maestro') with
# the same two patterns.

# Builds the dataset of prompts from the MIDI files.
# min_seq_len / max_seq_len are forwarded to DatasetTok as-is
# (presumably bounds on the token length of each sample -- confirm against the
# installed miditok version).
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)

# Loads the trained model checkpoint stored in ./runs
model = MistralForCausalLM.from_pretrained("./runs")

# One collator for the whole script, built from the tokenizer's PAD / BOS / EOS
# ids. It used to be constructed twice, with the first instance immediately
# overwritten; only the surviving configuration is kept.
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)

# Output directory for the generated files (created on demand, reused if present)
gen_results_path = Path('gen_res')
gen_results_path.mkdir(parents=True, exist_ok=True)

# Sampling settings handed to model.generate()
sampling_options = {
    "max_new_tokens": 512,  # extends samples by 512 tokens
    "num_beams": 1,         # no beam search
    "do_sample": True,      # but sample instead
    "temperature": 0.9,
    "top_k": 15,
    "top_p": 0.95,
    "epsilon_cutoff": 3e-4,
    "eta_cutoff": 1e-3,
}
generation_config = GenerationConfig(**sampling_options)

# Here the sequences are padded to the left, so that the last token along the time dimension
# is always the last token of each seq, allowing to efficiently generate by batch
collator.pad_on_left = True
# Presumably stops the collator from appending EOS to the prompts, so the model
# continues them instead of seeing an already-finished sequence -- confirm.
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
model.eval()

# Running index across all batches; names the output files 0.mid / 0.json, 1.mid / ...
count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config,
    )  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json).
    # Each row of `res` is the entire sequence (prompt + continuation), so the
    # prompts do not need to be iterated alongside it.
    for continuation in res:
        # Convert the tensor row to plain ints once and reuse it below
        token_ids = continuation.tolist()

        # Decode from a shallow copy so the ids written to JSON are exactly the
        # model output even if the tokenizer edits its input list in place
        # (a flat list of ints needs no deepcopy).
        midi = tokenizer.tokens_to_midi([list(token_ids)])

        # Set the track name to indicate it includes both the original and the continuation
        midi.tracks[0].name = f'Original sample and continuation ({len(token_ids)} tokens)'

        # Dump the MIDI file for the combined prompt and continuation
        midi.dump_midi(gen_results_path / f'{count}.mid')

        # Save the tokens of the same combined sequence next to the MIDI file
        tokenizer.save_tokens([token_ids], gen_results_path / f'{count}.json')

        count += 1