sunsetsobserver committed on
Commit
ae3131d
1 Parent(s): 2171a21

Add generate only continuation

Files changed (3)
  1. gen_res/0.json +0 -1
  2. gen_res/0.mid +0 -0
  3. generate_on_one_track.py +105 -0
gen_res/0.json DELETED
@@ -1 +0,0 @@
- {"ids": [[155, 32, 112, 122, 158, 37, 112, 121, 161, 41, 597, 163, 25, 469, 166, 36, 113, 122, 184, 25, 113, 119, 141, 25, 441, 145, 25, 113, 120, 147, 25, 111, 126, 29, 111, 130, 149, 27, 597, 154, 30, 221, 157, 25, 111, 124, 161, 25, 112, 121, 166, 29, 822, 308, 32, 112, 131, 149, 30, 112, 125, 159, 32, 180, 160, 29, 441, 166, 29, 322, 190, 32, 111, 123, 143, 25, 110, 130, 152, 37, 221, 156, 25, 113, 122, 158, 25, 113, 119, 161, 13, 322, 166, 30, 112, 121, 168, 25, 112, 122, 184, 37, 111, 130, 37, 112, 129, 25, 822, 143, 25, 192, 144, 25, 180, 148, 25, 378, 150, 25, 378, 153, 25, 192, 156, 25, 112, 123, 25, 279, 161, 25, 378, 165, 25, 112, 125, 184, 25, 112, 122, 25, 112, 124, 25, 113, 119, 190, 25, 226, 139, 25, 180, 140, 25, 180, 144, 25, 226, 25, 192, 151, 25, 192, 25, 220, 152, 25, 220, 13, 180, 157, 25, 220, 32, 378, 13, 178, 158, 25, 220, 163, 25, 112, 122, 25, 220, 25, 196, 164, 25, 185, 167, 25, 220, 25, 114, 117, 25, 180, 184, 25, 180, 37, 378, 141, 25, 378, 143, 25, 220, 25, 180, 149, 25, 192, 25, 220, 149, 25, 220, 49, 220, 151, 25, 220, 37, 180, 154, 25, 192, 25, 192, 155, 25, 192, 158, 25, 113, 119, 25, 378, 13, 175, 161, 25, 192, 25, 178, 13, 171, 162, 25, 185, 167, 25, 378, 25, 220, 25, 220, 206, 25, 220, 25, 178, 141, 25, 178, 25, 178, 144, 25, 180, 25, 178, 146, 25, 175, 25, 191, 148, 25, 177, 25, 175, 149, 25, 175, 25, 178, 151, 25, 178, 152, 25, 191, 25, 192, 154, 25, 180, 25, 175, 25, 175, 155, 25, 178, 157, 25, 192, 25, 180, 25, 180, 25, 180, 160, 25, 180, 25, 192, 25, 180, 161, 25, 178, 25, 180, 163, 25, 378, 25, 378, 25, 175, 25, 185, 168, 25, 220, 25, 220, 25, 180, 32, 191, 144, 25, 220, 25, 192, 25, 181, 25, 181, 146, 13, 173, 24, 175, 149, 25, 185, 27, 181, 19, 175, 151, 24, 180, 25, 175, 32, 178, 17, 174, 155, 24, 180, 20, 178, 25, 178, 27, 178, 27, 180, 154, 31, 192, 27, 180, 27, 180, 20, 178, 24, 175, 144, 24, 192, 20, 180, 27, 180, 27, 178, 145, 31, 180, 19, 178, 20, 178, 31, 178, 148, 31, 378, 19, 175, 24, 174, 32, 192, 151, 31, 180, 19, 180, 31, 192, 31, 180, 154, 27, 178, 32, 378, 23, 178, 25, 180, 155, 27, 192, 26, 180, 19, 175, 19, 220, 156, 19, 175, 19, 178, 157, 31, 378, 43, 192, 31, 192, 19, 180, 159, 25, 192, 31, 192, 19, 178, 31, 180, 160, 19, 178, 43, 192, 27, 178, 161, 26, 180, 26, 180, 162, 19, 180, 31, 180, 19, 192, 31, 180, 163, 31, 192, 19, 180, 19, 175, 31, 178, 19, 175, 31, 220, 165], [2, 184, 169, 141, 20, 113, 118, 22, 113, 118, 29, 113, 118, 38, 113, 118, 143, 17, 441, 24, 919, 33, 919, 40, 919, 147, 21, 113, 118, 28, 114, 118, 30, 113, 118, 39, 113, 118, 149, 16, 112, 128, 18, 111, 128, 25, 111, 128, 34, 111, 128, 165, 36, 219, 184, 13, 187, 139, 20, 179, 141, 27, 188, 143, 34, 196, 153, 20, 109, 132, 28, 211, 33, 211, 156, 30, 205, 35, 205, 158, 32, 205, 37, 211, 161, 27, 222, 34, 227, 165, 25, 222, 29, 222, 36, 235, 184, 26, 111, 128, 31, 111, 128, 153, 33, 407], [2, 184, 169, 141, 20, 113, 118, 22, 113, 118, 29, 113, 118, 38, 113, 118, 143, 17, 441, 24, 919, 33, 919, 40, 919, 147, 21, 113, 118, 28, 114, 118, 30, 113, 118, 39, 113, 118, 149, 16, 112, 128, 18, 111, 128, 25, 111, 128, 34, 111, 128, 165, 36, 219, 184, 13, 187, 139, 20, 179, 141, 27, 188, 143, 34, 196, 153, 20, 109, 132, 28, 211, 33, 211, 156, 30, 205, 35, 205, 158, 32, 205, 37, 211, 161, 27, 222, 34, 227, 165, 25, 222, 29, 222, 36, 235, 184, 26, 111, 128, 31, 111, 128, 153, 33, 407, 155, 32, 112, 122, 158, 37, 112, 121, 161, 41, 597, 163, 25, 469, 166, 36, 113, 122, 184, 25, 113, 119, 141, 25, 441, 145, 25, 113, 120, 147, 25, 111, 126, 29, 111, 130, 149, 27, 597, 154, 
30, 221, 157, 25, 111, 124, 161, 25, 112, 121, 166, 29, 822, 308, 32, 112, 131, 149, 30, 112, 125, 159, 32, 180, 160, 29, 441, 166, 29, 322, 190, 32, 111, 123, 143, 25, 110, 130, 152, 37, 221, 156, 25, 113, 122, 158, 25, 113, 119, 161, 13, 322, 166, 30, 112, 121, 168, 25, 112, 122, 184, 37, 111, 130, 37, 112, 129, 25, 822, 143, 25, 192, 144, 25, 180, 148, 25, 378, 150, 25, 378, 153, 25, 192, 156, 25, 112, 123, 25, 279, 161, 25, 378, 165, 25, 112, 125, 184, 25, 112, 122, 25, 112, 124, 25, 113, 119, 190, 25, 226, 139, 25, 180, 140, 25, 180, 144, 25, 226, 25, 192, 151, 25, 192, 25, 220, 152, 25, 220, 13, 180, 157, 25, 220, 32, 378, 13, 178, 158, 25, 220, 163, 25, 112, 122, 25, 220, 25, 196, 164, 25, 185, 167, 25, 220, 25, 114, 117, 25, 180, 184, 25, 180, 37, 378, 141, 25, 378, 143, 25, 220, 25, 180, 149, 25, 192, 25, 220, 149, 25, 220, 49, 220, 151, 25, 220, 37, 180, 154, 25, 192, 25, 192, 155, 25, 192, 158, 25, 113, 119, 25, 378, 13, 175, 161, 25, 192, 25, 178, 13, 171, 162, 25, 185, 167, 25, 378, 25, 220, 25, 220, 206, 25, 220, 25, 178, 141, 25, 178, 25, 178, 144, 25, 180, 25, 178, 146, 25, 175, 25, 191, 148, 25, 177, 25, 175, 149, 25, 175, 25, 178, 151, 25, 178, 152, 25, 191, 25, 192, 154, 25, 180, 25, 175, 25, 175, 155, 25, 178, 157, 25, 192, 25, 180, 25, 180, 25, 180, 160, 25, 180, 25, 192, 25, 180, 161, 25, 178, 25, 180, 163, 25, 378, 25, 378, 25, 175, 25, 185, 168, 25, 220, 25, 220, 25, 180, 32, 191, 144, 25, 220, 25, 192, 25, 181, 25, 181, 146, 13, 173, 24, 175, 149, 25, 185, 27, 181, 19, 175, 151, 24, 180, 25, 175, 32, 178, 17, 174, 155, 24, 180, 20, 178, 25, 178, 27, 178, 27, 180, 154, 31, 192, 27, 180, 27, 180, 20, 178, 24, 175, 144, 24, 192, 20, 180, 27, 180, 27, 178, 145, 31, 180, 19, 178, 20, 178, 31, 178, 148, 31, 378, 19, 175, 24, 174, 32, 192, 151, 31, 180, 19, 180, 31, 192, 31, 180, 154, 27, 178, 32, 378, 23, 178, 25, 180, 155, 27, 192, 26, 180, 19, 175, 19, 220, 156, 19, 175, 19, 178, 157, 31, 378, 43, 192, 31, 192, 19, 180, 159, 25, 192, 31, 192, 19, 178, 31, 180, 160, 19, 178, 43, 192, 27, 178, 161, 26, 180, 26, 180, 162, 19, 180, 31, 180, 19, 192, 31, 180, 163, 31, 192, 19, 180, 19, 175, 31, 178, 19, 175, 31, 220, 165]]}
 
 
gen_res/0.mid DELETED
Binary file (3.32 kB)
 
generate_on_one_track.py ADDED
@@ -0,0 +1,105 @@
+ from copy import deepcopy
+ from pathlib import Path
+ from random import shuffle
+
+ from torch import Tensor, argmax
+ from torch.utils.data import DataLoader
+ from torch.cuda import is_available as cuda_available, is_bf16_supported
+ from torch.backends.mps import is_available as mps_available
+ from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoTokenizer, MistralForCausalLM
+ from transformers.trainer_utils import set_seed
+ from evaluate import load as load_metric
+ from miditok import REMI, TokenizerConfig
+ from miditok.pytorch_data import DatasetTok, DataCollator
+ from tqdm import tqdm
+
+ # Our tokenizer's configuration
+ PITCH_RANGE = (21, 109)
+ BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
+ NUM_VELOCITIES = 24
+ SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
+ USE_CHORDS = False
+ USE_RESTS = False
+ USE_TEMPOS = True
+ USE_TIME_SIGNATURE = False
+ USE_PROGRAMS = False
+ NUM_TEMPOS = 32
+ TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)
+ TOKENIZER_PARAMS = {
+     "pitch_range": PITCH_RANGE,
+     "beat_res": BEAT_RES,
+     "num_velocities": NUM_VELOCITIES,
+     "special_tokens": SPECIAL_TOKENS,
+     "use_chords": USE_CHORDS,
+     "use_rests": USE_RESTS,
+     "use_tempos": USE_TEMPOS,
+     "use_time_signatures": USE_TIME_SIGNATURE,
+     "use_programs": USE_PROGRAMS,
+     "num_tempos": NUM_TEMPOS,
+     "tempo_range": TEMPO_RANGE,
+ }
+ config = TokenizerConfig(**TOKENIZER_PARAMS)
+
+ # Seed
+ set_seed(777)
+
+ # Loads the pretrained tokenizer
+ tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")
+
+ midi_paths = list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))
+
+ # Alternatively: list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi'))
+
+ # Loads tokens and creates the data collator
+ kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
+ dataset_test = DatasetTok(midi_paths, **kwargs_dataset)
+ collator = DataCollator(
+     tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"]
+ )
+
+ # Loads the trained model from the ./runs checkpoint
+ model = MistralForCausalLM.from_pretrained("./runs")
+
+ collator = DataCollator(tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"], copy_inputs_as_labels=True)
+
+ (gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
+ generation_config = GenerationConfig(
+     max_new_tokens=512,  # extends samples by 512 tokens
+     num_beams=1,         # no beam search
+     do_sample=True,      # but sample instead
+     temperature=0.9,
+     top_k=15,
+     top_p=0.95,
+     epsilon_cutoff=3e-4,
+     eta_cutoff=1e-3,
+ )
+
+ # The sequences are padded on the left, so that the last token along the time dimension
+ # is always the last token of each sequence, which allows efficient batched generation
+ collator.pad_on_left = True
+ collator.eos_token = None
+ dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
+ model.eval()
+ count = 0
+ for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
+     res = model.generate(
+         inputs=batch["input_ids"].to(model.device),
+         attention_mask=batch["attention_mask"].to(model.device),
+         generation_config=generation_config)  # (N,T)
+
+     # Saves the generated music as MIDI files and tokens (json)
+     for prompt, continuation in zip(batch["input_ids"], res):
+         # Generate the MIDI for the entire sequence (prompt + continuation)
+         midi = tokenizer.tokens_to_midi([deepcopy(continuation.tolist())])
+
+         # Set the track name to indicate it includes both the original and the continuation
+         midi.tracks[0].name = f'Original sample and continuation ({len(continuation)} tokens)'
+
+         # Dump the MIDI file for the combined prompt and continuation
+         midi.dump_midi(gen_results_path / f'{count}.mid')
+
+         # Optionally, save the tokens for the combined sequence
+         tokens = [continuation.tolist()]  # This time, only saving the combined sequence
+         tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')
+
+         count += 1
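
Note that model.generate returns the prompt tokens followed by the newly sampled ones, so the MIDI and JSON files written above contain the original excerpt together with its continuation. Below is a minimal sketch, not part of this commit, of how only the newly generated part could be exported instead; it reuses the tokenizer, model, dataloader_test, generation_config and gen_results_path objects defined in generate_on_one_track.py, and the output file name is illustrative.

# Minimal sketch: export only the newly generated tokens, without the prompt.
# With left padding, the prompt occupies the first `batch["input_ids"].shape[1]`
# positions of each generated sequence, so everything after that index is new.
for batch in dataloader_test:
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)
    prompt_len = batch["input_ids"].shape[1]
    for continuation in res:
        new_tokens = continuation[prompt_len:].tolist()  # generated part only
        midi = tokenizer.tokens_to_midi([new_tokens])
        midi.dump_midi(gen_results_path / "continuation_only.mid")  # illustrative file name
    break  # sketch: only the first batch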
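
The JSON files written by tokenizer.save_tokens follow the same layout as the deleted gen_res/0.json above: an object with an "ids" key holding one list of token ids per track. A minimal sketch of reloading such a file and decoding it back to MIDI with the same tokenizer; the file paths are illustrative.

import json
from pathlib import Path

# Reload a saved token file ({"ids": [[...], ...]}) and decode it back to MIDI.
# Assumes `tokenizer` is the REMI tokenizer loaded in generate_on_one_track.py.
with open(Path("gen_res") / "0.json") as f:
    data = json.load(f)

midi = tokenizer.tokens_to_midi(data["ids"])  # one token sequence per track
midi.dump_midi(Path("gen_res") / "0_decoded.mid")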