m41w4r3.exe commited on
Commit
7edf1ce
1 Parent(s): 2ec0615

minor initial fixes

Browse files
Files changed (5) hide show
  1. app.py +0 -1
  2. familizer.py +137 -0
  3. generate.py +1 -4
  4. generation_utils.py +141 -0
  5. requirements.txt +16 -1
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  from load import LoadModel
3
  from generate import GenerateMidiText
4
  from constants import INSTRUMENT_CLASSES
5
- from encoder import MIDIEncoder
6
  from decoder import TextDecoder
7
  from utils import get_miditok, index_has_substring
8
  from playback import get_music
 
2
  from load import LoadModel
3
  from generate import GenerateMidiText
4
  from constants import INSTRUMENT_CLASSES
 
5
  from decoder import TextDecoder
6
  from utils import get_miditok, index_has_substring
7
  from playback import get_music
familizer.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from joblib import Parallel, delayed
3
+ from pathlib import Path
4
+ from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
5
+ from utils import get_files, timeit, FileCompressor
6
+
7
+
8
+ class Familizer:
9
+ def __init__(self, n_jobs=-1, arbitrary=False):
10
+ self.n_jobs = n_jobs
11
+ self.reverse_family(arbitrary)
12
+
13
+ def get_family_number(self, program_number):
14
+ """
15
+ Given a MIDI instrument number, return its associated instrument family number.
16
+ """
17
+ for instrument_class in INSTRUMENT_CLASSES:
18
+ if program_number in instrument_class["program_range"]:
19
+ return instrument_class["family_number"]
20
+
21
+ def reverse_family(self, arbitrary):
22
+ """
23
+ Create a dictionary of family numbers to randomly assigned program numbers.
24
+ This is used to reverse the family number tokens back to program number tokens.
25
+ """
26
+
27
+ if arbitrary is True:
28
+ int_class = INSTRUMENT_TRANSFER_CLASSES
29
+ else:
30
+ int_class = INSTRUMENT_CLASSES
31
+
32
+ self.reference_programs = {}
33
+ for family in int_class:
34
+ self.reference_programs[family["family_number"]] = random.choice(
35
+ family["program_range"]
36
+ )
37
+
38
+ def get_program_number(self, family_number):
39
+ """
40
+ Given given a family number return a random program number in the respective program_range.
41
+ This is the reverse operation of get_family_number.
42
+ """
43
+ assert family_number in self.reference_programs
44
+ return self.reference_programs[family_number]
45
+
46
+ # Replace instruments in text files
47
+ def replace_instrument_token(self, token):
48
+ """
49
+ Given a MIDI program number in a word token, replace it with the family or program
50
+ number token depending on the operation.
51
+ e.g. INST=86 -> INST=10
52
+ """
53
+ inst_number = int(token.split("=")[1])
54
+ if self.operation == "family":
55
+ return "INST=" + str(self.get_family_number(inst_number))
56
+ elif self.operation == "program":
57
+ return "INST=" + str(self.get_program_number(inst_number))
58
+
59
+ def replace_instrument_in_text(self, text):
60
+ """Given a text piece, replace all instrument tokens with family number tokens."""
61
+ return " ".join(
62
+ [
63
+ self.replace_instrument_token(token)
64
+ if token.startswith("INST=") and not token == "INST=DRUMS"
65
+ else token
66
+ for token in text.split(" ")
67
+ ]
68
+ )
69
+
70
+ def replace_instruments_in_file(self, file):
71
+ """Given a text file, replace all instrument tokens with family number tokens."""
72
+ text = file.read_text()
73
+ file.write_text(self.replace_instrument_in_text(text))
74
+
75
+ @timeit
76
+ def replace_instruments(self):
77
+ """
78
+ Given a directory of text files:
79
+ Replace all instrument tokens with family number tokens.
80
+ """
81
+ files = get_files(self.output_directory, extension="txt")
82
+ Parallel(n_jobs=self.n_jobs)(
83
+ delayed(self.replace_instruments_in_file)(file) for file in files
84
+ )
85
+
86
+ def replace_tokens(self, input_directory, output_directory, operation):
87
+ """
88
+ Given a directory and an operation, perform the operation on all text files in the directory.
89
+ operation can be either 'family' or 'program'.
90
+ """
91
+ self.input_directory = input_directory
92
+ self.output_directory = output_directory
93
+ self.operation = operation
94
+
95
+ # Uncompress files, replace tokens, compress files
96
+ fc = FileCompressor(self.input_directory, self.output_directory, self.n_jobs)
97
+ fc.unzip()
98
+ self.replace_instruments()
99
+ fc.zip()
100
+ print(self.operation + " complete.")
101
+
102
+ def to_family(self, input_directory, output_directory):
103
+ """
104
+ Given a directory containing zip files, replace all instrument tokens with
105
+ family number tokens. The output is a directory of zip files.
106
+ """
107
+ self.replace_tokens(input_directory, output_directory, "family")
108
+
109
+ def to_program(self, input_directory, output_directory):
110
+ """
111
+ Given a directory containing zip files, replace all instrument tokens with
112
+ program number tokens. The output is a directory of zip files.
113
+ """
114
+ self.replace_tokens(input_directory, output_directory, "program")
115
+
116
+
117
+ if __name__ == "__main__":
118
+
119
+ # Choose number of jobs for parallel processing
120
+ n_jobs = -1
121
+
122
+ # Instantiate Familizer
123
+ familizer = Familizer(n_jobs)
124
+
125
+ # Choose directory to process for program
126
+ input_directory = Path("midi/dataset/first_selection/validate").resolve() # fmt: skip
127
+ output_directory = input_directory / "family"
128
+
129
+ # familize files
130
+ familizer.to_family(input_directory, output_directory)
131
+
132
+ # Choose directory to process for family
133
+ # input_directory = Path("../data/music_picks/encoded_samples/validate/family").resolve() # fmt: skip
134
+ # output_directory = input_directory.parent / "program"
135
+
136
+ # # programize files
137
+ # familizer.to_program(input_directory, output_directory)
generate.py CHANGED
@@ -1,11 +1,8 @@
1
  from generation_utils import *
2
  from utils import WriteTextMidiToFile, get_miditok
3
  from load import LoadModel
4
- from constants import INSTRUMENT_CLASSES
5
-
6
- ## import for execution
7
  from decoder import TextDecoder
8
- from playback import get_music, show_piano_roll
9
 
10
 
11
  class GenerateMidiText:
 
1
  from generation_utils import *
2
  from utils import WriteTextMidiToFile, get_miditok
3
  from load import LoadModel
 
 
 
4
  from decoder import TextDecoder
5
+ from playback import get_music
6
 
7
 
8
  class GenerateMidiText:
generation_utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+ from constants import INSTRUMENT_CLASSES
6
+
7
+ # matplotlib settings
8
+ matplotlib.use("Agg") # for server
9
+ matplotlib.rcParams["xtick.major.size"] = 0
10
+ matplotlib.rcParams["ytick.major.size"] = 0
11
+ matplotlib.rcParams["axes.facecolor"] = "grey"
12
+ matplotlib.rcParams["axes.edgecolor"] = "none"
13
+
14
+
15
+ def define_generation_dir(model_repo_path):
16
+ #### to remove later ####
17
+ if model_repo_path == "models/model_2048_fake_wholedataset":
18
+ model_repo_path = "misnaej/the-jam-machine"
19
+ #### to remove later ####
20
+ generated_sequence_files_path = f"midi/generated/{model_repo_path}"
21
+ if not os.path.exists(generated_sequence_files_path):
22
+ os.makedirs(generated_sequence_files_path)
23
+ return generated_sequence_files_path
24
+
25
+
26
+ def bar_count_check(sequence, n_bars):
27
+ """check if the sequence contains the right number of bars"""
28
+ sequence = sequence.split(" ")
29
+ # find occurences of "BAR_END" in a "sequence"
30
+ # I don't check for "BAR_START" because it is not always included in "sequence"
31
+ # e.g. BAR_START is included the prompt when generating one more bar
32
+ bar_count = 0
33
+ for seq in sequence:
34
+ if seq == "BAR_END":
35
+ bar_count += 1
36
+ bar_count_matches = bar_count == n_bars
37
+ if not bar_count_matches:
38
+ print(f"Bar count is {bar_count} - but should be {n_bars}")
39
+ return bar_count_matches, bar_count
40
+
41
+
42
+ def print_inst_classes(INSTRUMENT_CLASSES):
43
+ """Print the instrument classes"""
44
+ for classe in INSTRUMENT_CLASSES:
45
+ print(f"{classe}")
46
+
47
+
48
+ def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
49
+ """Check if the prompt instrument are in the tokenizer vocab"""
50
+ for inst in inst_prompt_list:
51
+ if f"INST={inst}" not in tokenizer.vocab:
52
+ instruments_in_dataset = np.sort(
53
+ [tok.split("=")[-1] for tok in tokenizer.vocab if "INST" in tok]
54
+ )
55
+ print_inst_classes(INSTRUMENT_CLASSES)
56
+ raise ValueError(
57
+ f"""The instrument {inst} is not in the tokenizer vocabulary.
58
+ Available Instruments: {instruments_in_dataset}"""
59
+ )
60
+
61
+
62
+ def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
63
+ """Forcing the generated sequence to have the expected length
64
+ expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""
65
+
66
+ if bar_count - expected_length > 0: # Cut the sequence if too long
67
+ full_piece = ""
68
+ splited = generated.split("BAR_END ")
69
+ for count, spl in enumerate(splited):
70
+ if count < expected_length:
71
+ full_piece += spl + "BAR_END "
72
+
73
+ full_piece += "TRACK_END "
74
+ full_piece = input_prompt + full_piece
75
+ print(f"Generated sequence trunkated at {expected_length} bars")
76
+ bar_count_checks = True
77
+
78
+ elif bar_count - expected_length < 0: # Do nothing it the sequence if too short
79
+ full_piece = input_prompt + generated
80
+ bar_count_checks = False
81
+ print(f"--- Generated sequence is too short - Force Regeration ---")
82
+
83
+ return full_piece, bar_count_checks
84
+
85
+
86
+ def get_max_time(inst_midi):
87
+ max_time = 0
88
+ for inst in inst_midi.instruments:
89
+ max_time = max(max_time, inst.get_end_time())
90
+ return max_time
91
+
92
+
93
+ def plot_piano_roll(inst_midi):
94
+ piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
95
+ piano_roll_fig.tight_layout()
96
+ piano_roll_fig.patch.set_alpha(0.1)
97
+ inst_count = 0
98
+ beats_per_bar = 4
99
+ sec_per_beat = 0.5
100
+ next_beat = max(inst_midi.get_beats()) + np.diff(inst_midi.get_beats())[0]
101
+ bars_time = np.append(inst_midi.get_beats(), (next_beat))[::beats_per_bar].astype(
102
+ int
103
+ )
104
+ for inst in inst_midi.instruments:
105
+ inst_count += 1
106
+ plt.subplot(len(inst_midi.instruments), 1, inst_count)
107
+
108
+ for bar in bars_time:
109
+ plt.axvline(bar, color="grey", linewidth=0.5)
110
+ octaves = np.arange(0, 128, 12)
111
+ for octave in octaves:
112
+ plt.axhline(octave, color="grey", linewidth=0.5)
113
+ plt.yticks(octaves, visible=False)
114
+
115
+ p_midi_note_list = inst.notes
116
+ note_time = []
117
+ note_pitch = []
118
+ for note in p_midi_note_list:
119
+ note_time.append([note.start, note.end])
120
+ note_pitch.append([note.pitch, note.pitch])
121
+
122
+ plt.plot(
123
+ np.array(note_time).T,
124
+ np.array(note_pitch).T,
125
+ color="purple",
126
+ linewidth=3,
127
+ solid_capstyle="butt",
128
+ )
129
+ plt.ylim(0, 128)
130
+ xticks = np.array(bars_time)[:-1]
131
+ plt.tight_layout()
132
+ plt.xlim(min(bars_time), max(bars_time))
133
+ # plt.xlabel("bars")
134
+ plt.xticks(
135
+ xticks + 0.5 * beats_per_bar * sec_per_beat,
136
+ labels=xticks.argsort() + 1,
137
+ visible=False,
138
+ )
139
+ plt.title(inst.name, fontsize=10, color="white")
140
+
141
+ return piano_roll_fig
requirements.txt CHANGED
@@ -2,4 +2,19 @@ gradio
2
  matplotlib
3
  sys
4
  matplotlib
5
- numpy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  matplotlib
3
  sys
4
  matplotlib
5
+ numpy
6
+ joblib
7
+ pathlib
8
+ random
9
+ transformers
10
+ os
11
+ miditok
12
+ librosa
13
+ pretty_midi
14
+ pydub
15
+ shutil
16
+ scipy
17
+ zipfile
18
+ time
19
+ json
20
+ datetime