m41w4r3.exe
committed on
Commit
•
7edf1ce
1
Parent(s):
2ec0615
minor initial fixes
Browse files- app.py +0 -1
- familizer.py +137 -0
- generate.py +1 -4
- generation_utils.py +141 -0
- requirements.txt +16 -1
app.py
CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
|
|
2 |
from load import LoadModel
|
3 |
from generate import GenerateMidiText
|
4 |
from constants import INSTRUMENT_CLASSES
|
5 |
-
from encoder import MIDIEncoder
|
6 |
from decoder import TextDecoder
|
7 |
from utils import get_miditok, index_has_substring
|
8 |
from playback import get_music
|
|
|
2 |
from load import LoadModel
|
3 |
from generate import GenerateMidiText
|
4 |
from constants import INSTRUMENT_CLASSES
|
|
|
5 |
from decoder import TextDecoder
|
6 |
from utils import get_miditok, index_has_substring
|
7 |
from playback import get_music
|
familizer.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from joblib import Parallel, delayed
|
3 |
+
from pathlib import Path
|
4 |
+
from constants import INSTRUMENT_CLASSES, INSTRUMENT_TRANSFER_CLASSES
|
5 |
+
from utils import get_files, timeit, FileCompressor
|
6 |
+
|
7 |
+
|
8 |
+
class Familizer:
    """Translate ``INST=<number>`` tokens between MIDI programs and families.

    Two operations are supported:

    * ``"family"``: a program number is replaced by the family number of the
      instrument class whose ``program_range`` contains it.
    * ``"program"``: a family number is replaced by the reference program
      that was drawn at random for that family when the instance was built.
    """

    # Operations understood by replace_tokens / replace_instrument_token.
    _OPERATIONS = ("family", "program")

    def __init__(self, n_jobs=-1, arbitrary=False):
        """
        Args:
            n_jobs: Number of parallel jobs passed to joblib and FileCompressor.
            arbitrary: When truthy, reference programs are drawn from
                INSTRUMENT_TRANSFER_CLASSES instead of INSTRUMENT_CLASSES.
        """
        self.n_jobs = n_jobs
        # Filled in by replace_tokens; defined here so the attributes always
        # exist, even when a helper method is called directly.
        self.input_directory = None
        self.output_directory = None
        self.operation = None
        self.reverse_family(arbitrary)

    def get_family_number(self, program_number):
        """
        Given a MIDI instrument number, return its associated instrument family number.

        Returns None when no instrument class contains ``program_number``.
        """
        for instrument_class in INSTRUMENT_CLASSES:
            if program_number in instrument_class["program_range"]:
                return instrument_class["family_number"]
        return None

    def reverse_family(self, arbitrary):
        """
        Create a dictionary of family numbers to randomly assigned program numbers.
        This is used to reverse the family number tokens back to program number tokens.
        """
        int_class = INSTRUMENT_TRANSFER_CLASSES if arbitrary else INSTRUMENT_CLASSES

        self.reference_programs = {
            family["family_number"]: random.choice(family["program_range"])
            for family in int_class
        }

    def get_program_number(self, family_number):
        """
        Given a family number, return the reference program number that was
        drawn for it in reverse_family.
        This is the reverse operation of get_family_number.
        """
        assert family_number in self.reference_programs, (
            f"Unknown family number: {family_number}"
        )
        return self.reference_programs[family_number]

    # Replace instruments in text files
    def replace_instrument_token(self, token):
        """
        Given a MIDI program number in a word token, replace it with the family or program
        number token depending on the operation.
        e.g. INST=86 -> INST=10

        Raises:
            ValueError: If ``self.operation`` is neither 'family' nor 'program'.
        """
        inst_number = int(token.split("=")[1])
        if self.operation == "family":
            # NOTE(review): a program outside every program_range yields
            # "INST=None" here, exactly as before - confirm whether this
            # should be rejected instead.
            return "INST=" + str(self.get_family_number(inst_number))
        if self.operation == "program":
            return "INST=" + str(self.get_program_number(inst_number))
        # Previously this fell through and returned None, which only surfaced
        # later as a TypeError inside " ".join(...).
        raise ValueError(
            f"Unknown operation {self.operation!r}; expected one of {self._OPERATIONS}"
        )

    def replace_instrument_in_text(self, text):
        """Given a text piece, convert every instrument token except INST=DRUMS."""
        return " ".join(
            self.replace_instrument_token(token)
            if token.startswith("INST=") and token != "INST=DRUMS"
            else token
            for token in text.split(" ")
        )

    def replace_instruments_in_file(self, file):
        """Given a text file, convert its instrument tokens in place."""
        text = file.read_text()
        file.write_text(self.replace_instrument_in_text(text))

    @timeit
    def replace_instruments(self):
        """
        Given a directory of text files:
        Convert the instrument tokens of every ``txt`` file found in
        ``self.output_directory``, in parallel.
        """
        files = get_files(self.output_directory, extension="txt")
        Parallel(n_jobs=self.n_jobs)(
            delayed(self.replace_instruments_in_file)(file) for file in files
        )

    def replace_tokens(self, input_directory, output_directory, operation):
        """
        Given a directory and an operation, perform the operation on all text files in the directory.
        operation can be either 'family' or 'program'.

        Raises:
            ValueError: If ``operation`` is not supported. The check runs
                before any file is unzipped or rewritten.
        """
        if operation not in self._OPERATIONS:
            raise ValueError(
                f"Unknown operation {operation!r}; expected one of {self._OPERATIONS}"
            )

        self.input_directory = input_directory
        self.output_directory = output_directory
        self.operation = operation

        # Uncompress files, replace tokens, compress files
        fc = FileCompressor(self.input_directory, self.output_directory, self.n_jobs)
        fc.unzip()
        self.replace_instruments()
        fc.zip()
        print(self.operation + " complete.")

    def to_family(self, input_directory, output_directory):
        """
        Given a directory containing zip files, replace all instrument tokens with
        family number tokens. The output is a directory of zip files.
        """
        self.replace_tokens(input_directory, output_directory, "family")

    def to_program(self, input_directory, output_directory):
        """
        Given a directory containing zip files, replace all instrument tokens with
        program number tokens. The output is a directory of zip files.
        """
        self.replace_tokens(input_directory, output_directory, "program")
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
    # n_jobs=-1 is forwarded unchanged to the parallel helpers.
    familizer = Familizer(-1)

    # Program -> family: convert the zip files of this directory and store
    # the result in its "family" sub-directory.
    input_directory = Path("midi/dataset/first_selection/validate").resolve()  # fmt: skip
    familizer.to_family(input_directory, input_directory / "family")

    # Family -> program (disabled): reverse conversion example.
    # input_directory = Path("../data/music_picks/encoded_samples/validate/family").resolve()  # fmt: skip
    # output_directory = input_directory.parent / "program"
    # familizer.to_program(input_directory, output_directory)
|
generate.py
CHANGED
@@ -1,11 +1,8 @@
|
|
1 |
from generation_utils import *
|
2 |
from utils import WriteTextMidiToFile, get_miditok
|
3 |
from load import LoadModel
|
4 |
-
from constants import INSTRUMENT_CLASSES
|
5 |
-
|
6 |
-
## import for execution
|
7 |
from decoder import TextDecoder
|
8 |
-
from playback import get_music
|
9 |
|
10 |
|
11 |
class GenerateMidiText:
|
|
|
1 |
from generation_utils import *
|
2 |
from utils import WriteTextMidiToFile, get_miditok
|
3 |
from load import LoadModel
|
|
|
|
|
|
|
4 |
from decoder import TextDecoder
|
5 |
+
from playback import get_music
|
6 |
|
7 |
|
8 |
class GenerateMidiText:
|
generation_utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import matplotlib
|
5 |
+
from constants import INSTRUMENT_CLASSES
|
6 |
+
|
7 |
+
# Matplotlib configuration, applied once at import time.
matplotlib.use("Agg")  # for server
matplotlib.rcParams.update(
    {
        "xtick.major.size": 0,
        "ytick.major.size": 0,
        "axes.facecolor": "grey",
        "axes.edgecolor": "none",
    }
)
|
13 |
+
|
14 |
+
|
15 |
+
def define_generation_dir(model_repo_path):
    """Return the directory for generated sequences, creating it if needed.

    The directory is ``midi/generated/<model_repo_path>``, relative to the
    current working directory.

    Args:
        model_repo_path: Model path or repository identifier.

    Returns:
        The directory path as a string.
    """
    #### to remove later ####
    # Temporary alias: this local model path is stored under another name.
    if model_repo_path == "models/model_2048_fake_wholedataset":
        model_repo_path = "misnaej/the-jam-machine"
    #### to remove later ####
    generated_sequence_files_path = f"midi/generated/{model_repo_path}"
    # exist_ok=True removes the race between checking for the directory and
    # creating it when several processes start at the same time.
    os.makedirs(generated_sequence_files_path, exist_ok=True)
    return generated_sequence_files_path
|
24 |
+
|
25 |
+
|
26 |
+
def bar_count_check(sequence, n_bars):
    """Check whether a token sequence contains the expected number of bars.

    Args:
        sequence: Space-separated token string.
        n_bars: Expected number of bars.

    Returns:
        A ``(bar_count_matches, bar_count)`` tuple, where ``bar_count`` is the
        number of ``BAR_END`` tokens found. A message is printed on mismatch.
    """
    # Only occurrences of "BAR_END" are counted: "BAR_START" is not always
    # included in `sequence`, e.g. it is part of the prompt when generating
    # one more bar.
    bar_count = sequence.split(" ").count("BAR_END")
    bar_count_matches = bar_count == n_bars
    if not bar_count_matches:
        print(f"Bar count is {bar_count} - but should be {n_bars}")
    return bar_count_matches, bar_count
|
40 |
+
|
41 |
+
|
42 |
+
def print_inst_classes(INSTRUMENT_CLASSES):
    """Print every entry of ``INSTRUMENT_CLASSES``, one per line."""
    for entry in INSTRUMENT_CLASSES:
        print("{}".format(entry))
|
46 |
+
|
47 |
+
|
48 |
+
def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
    """Check that every prompted instrument is in the tokenizer vocabulary.

    Args:
        tokenizer: Object whose ``vocab`` attribute contains the token strings.
        inst_prompt_list: Instrument identifiers; each one is looked up as the
            token ``INST=<identifier>``.

    Raises:
        ValueError: For the first instrument whose token is missing. The
            instrument classes are printed before raising, and the message
            lists the instruments found in the vocabulary.
    """

    def _numeric_first(name):
        # Numeric identifiers sort by value ("2" before "10"); anything else
        # (e.g. "DRUMS") follows alphabetically.
        return (0, int(name), "") if name.isdecimal() else (1, 0, name)

    for inst in inst_prompt_list:
        if f"INST={inst}" in tokenizer.vocab:
            continue
        # Prefix match: only real "INST=<x>" tokens are listed, not every
        # token that merely contains the substring "INST".
        instruments_in_dataset = sorted(
            (tok.split("=")[-1] for tok in tokenizer.vocab if tok.startswith("INST=")),
            key=_numeric_first,
        )
        print_inst_classes(INSTRUMENT_CLASSES)
        raise ValueError(
            f"The instrument {inst} is not in the tokenizer vocabulary.\n"
            f"Available Instruments: {instruments_in_dataset}"
        )
|
60 |
+
|
61 |
+
|
62 |
+
def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
    """Force the generated sequence to have the expected number of bars.

    ``expected_length`` and ``bar_count`` refer to the newly generated part
    only (without the input prompt).

    Args:
        input_prompt: Prompt text prepended to the result.
        generated: Newly generated token text.
        bar_count: Number of bars counted in ``generated``.
        expected_length: Number of bars wanted in ``generated``.

    Returns:
        A ``(full_piece, bar_count_checks)`` tuple. ``bar_count_checks`` is
        False only when the sequence is too short and has to be regenerated.
    """
    if bar_count > expected_length:  # Cut the sequence if too long
        bars = generated.split("BAR_END ")
        # max(..., 0) keeps a negative expected_length from slicing from the
        # end of the list; in that case no bar is kept.
        kept_bars = bars[: max(expected_length, 0)]
        truncated = "".join(bar + "BAR_END " for bar in kept_bars)
        full_piece = input_prompt + truncated + "TRACK_END "
        print(f"Generated sequence trunkated at {expected_length} bars")
        bar_count_checks = True

    elif bar_count < expected_length:  # Do nothing if the sequence is too short
        full_piece = input_prompt + generated
        bar_count_checks = False
        print("--- Generated sequence is too short - Force Regeration ---")

    else:
        # Exact match: previously neither branch ran and the return statement
        # raised UnboundLocalError.
        full_piece = input_prompt + generated
        bar_count_checks = True

    return full_piece, bar_count_checks
|
84 |
+
|
85 |
+
|
86 |
+
def get_max_time(inst_midi):
    """Return the latest ``get_end_time()`` among the instruments, at least 0."""
    # The leading 0 is the floor used when there are no instruments.
    end_times = [0]
    end_times.extend(track.get_end_time() for track in inst_midi.instruments)
    return max(end_times)
|
91 |
+
|
92 |
+
|
93 |
+
def plot_piano_roll(inst_midi):
    """Draw one piano-roll subplot per instrument and return the figure.

    Args:
        inst_midi: MIDI object exposing ``instruments`` (each with ``notes``
            and ``name``) and ``get_beats()`` - presumably a
            ``pretty_midi.PrettyMIDI``; confirm against the caller.

    Returns:
        The matplotlib figure that holds the subplots. The figure is not
        closed here.
    """
    instruments = inst_midi.instruments
    piano_roll_fig = plt.figure(figsize=(25, 3 * len(instruments)))
    piano_roll_fig.tight_layout()
    piano_roll_fig.patch.set_alpha(0.1)
    beats_per_bar = 4
    sec_per_beat = 0.5

    # Query the beat grid once instead of three times.
    beats = np.asarray(inst_midi.get_beats())
    # With fewer than two beats there is no interval to measure, so fall back
    # to the nominal beat length rather than failing on np.diff(...)[0].
    beat_step = np.diff(beats)[0] if len(beats) > 1 else sec_per_beat
    last_beat = max(beats) if len(beats) else 0
    next_beat = last_beat + beat_step
    # NOTE(review): astype(int) truncates the bar times to whole numbers -
    # confirm this is intended for tempos where bars do not start on them.
    bars_time = np.append(beats, next_beat)[::beats_per_bar].astype(int)

    # Values shared by every subplot.
    octaves = np.arange(0, 128, 12)
    xticks = np.array(bars_time)[:-1]
    tick_positions = xticks + 0.5 * beats_per_bar * sec_per_beat
    tick_labels = xticks.argsort() + 1

    for inst_count, inst in enumerate(instruments, start=1):
        plt.subplot(len(instruments), 1, inst_count)

        # Background grid: one vertical line per bar, one horizontal per octave.
        for bar in bars_time:
            plt.axvline(bar, color="grey", linewidth=0.5)
        for octave in octaves:
            plt.axhline(octave, color="grey", linewidth=0.5)
        plt.yticks(octaves, visible=False)

        # Each note becomes a horizontal segment from its start to its end
        # at the note's pitch.
        note_time = [[note.start, note.end] for note in inst.notes]
        note_pitch = [[note.pitch, note.pitch] for note in inst.notes]

        plt.plot(
            np.array(note_time).T,
            np.array(note_pitch).T,
            color="purple",
            linewidth=3,
            solid_capstyle="butt",
        )
        plt.ylim(0, 128)
        plt.tight_layout()
        plt.xlim(min(bars_time), max(bars_time))
        # plt.xlabel("bars")
        plt.xticks(
            tick_positions,
            labels=tick_labels,
            visible=False,
        )
        plt.title(inst.name, fontsize=10, color="white")

    return piano_roll_fig
|
requirements.txt
CHANGED
@@ -2,4 +2,19 @@ gradio
|
|
2 |
matplotlib
|
3 |
sys
|
4 |
matplotlib
|
5 |
-
numpy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
matplotlib
|
3 |
sys
|
4 |
matplotlib
|
5 |
+
numpy
|
6 |
+
joblib
|
7 |
+
pathlib
|
8 |
+
random
|
9 |
+
transformers
|
10 |
+
os
|
11 |
+
miditok
|
12 |
+
librosa
|
13 |
+
pretty_midi
|
14 |
+
pydub
|
15 |
+
shutil
|
16 |
+
scipy
|
17 |
+
zipfile
|
18 |
+
time
|
19 |
+
json
|
20 |
+
datetime
|