|
import os |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import matplotlib |
|
|
|
from constants import INSTRUMENT_CLASSES |
|
from playback import get_music, show_piano_roll |
|
|
|
|
|
matplotlib.use("Agg") |
|
matplotlib.rcParams["xtick.major.size"] = 0 |
|
matplotlib.rcParams["ytick.major.size"] = 0 |
|
matplotlib.rcParams["axes.facecolor"] = "none" |
|
matplotlib.rcParams["axes.edgecolor"] = "grey" |
|
|
|
|
|
def define_generation_dir(model_repo_path): |
|
generated_sequence_files_path = f"midi/generated/{model_repo_path}" |
|
if not os.path.exists(generated_sequence_files_path): |
|
os.makedirs(generated_sequence_files_path) |
|
return generated_sequence_files_path |
|
|
|
|
|
def bar_count_check(sequence, n_bars): |
|
"""check if the sequence contains the right number of bars""" |
|
sequence = sequence.split(" ") |
|
|
|
|
|
|
|
bar_count = 0 |
|
for seq in sequence: |
|
if seq == "BAR_END": |
|
bar_count += 1 |
|
bar_count_matches = bar_count == n_bars |
|
if not bar_count_matches: |
|
print(f"Bar count is {bar_count} - but should be {n_bars}") |
|
return bar_count_matches, bar_count |
|
|
|
|
|
def print_inst_classes(INSTRUMENT_CLASSES): |
|
"""Print the instrument classes""" |
|
for classe in INSTRUMENT_CLASSES: |
|
print(f"{classe}") |
|
|
|
|
|
def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list): |
|
"""Check if the prompt instrument are in the tokenizer vocab""" |
|
for inst in inst_prompt_list: |
|
if f"INST={inst}" not in tokenizer.vocab: |
|
instruments_in_dataset = np.sort( |
|
[tok.split("=")[-1] for tok in tokenizer.vocab if "INST" in tok] |
|
) |
|
print_inst_classes(INSTRUMENT_CLASSES) |
|
raise ValueError( |
|
f"""The instrument {inst} is not in the tokenizer vocabulary. |
|
Available Instruments: {instruments_in_dataset}""" |
|
) |
|
|
|
|
|
|
|
def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list): |
|
pass |
|
|
|
|
|
def forcing_bar_count(input_prompt, generated, bar_count, expected_length): |
|
"""Forcing the generated sequence to have the expected length |
|
expected_length and bar_count refers to the length of newly_generated_only (without input prompt)""" |
|
|
|
if bar_count - expected_length > 0: |
|
full_piece = "" |
|
splited = generated.split("BAR_END ") |
|
for count, spl in enumerate(splited): |
|
if count < expected_length: |
|
full_piece += spl + "BAR_END " |
|
|
|
full_piece += "TRACK_END " |
|
full_piece = input_prompt + full_piece |
|
print(f"Generated sequence trunkated at {expected_length} bars") |
|
bar_count_checks = True |
|
|
|
elif bar_count - expected_length < 0: |
|
full_piece = input_prompt + generated |
|
bar_count_checks = False |
|
print(f"--- Generated sequence is too short - Force Regeration ---") |
|
|
|
return full_piece, bar_count_checks |
|
|
|
|
|
def get_max_time(inst_midi): |
|
max_time = 0 |
|
for inst in inst_midi.instruments: |
|
max_time = max(max_time, inst.get_end_time()) |
|
return max_time |
|
|
|
|
|
def plot_piano_roll(inst_midi): |
|
piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments))) |
|
piano_roll_fig.tight_layout() |
|
piano_roll_fig.patch.set_alpha(0) |
|
inst_count = 0 |
|
beats_per_bar = 4 |
|
sec_per_beat = 0.5 |
|
next_beat = max(inst_midi.get_beats()) + np.diff(inst_midi.get_beats())[0] |
|
bars_time = np.append(inst_midi.get_beats(), (next_beat))[::beats_per_bar].astype( |
|
int |
|
) |
|
for inst in inst_midi.instruments: |
|
|
|
if inst.name == "Drums": |
|
color = "purple" |
|
elif inst.name == "Synth Bass 1": |
|
color = "orange" |
|
else: |
|
color = "green" |
|
|
|
inst_count += 1 |
|
plt.subplot(len(inst_midi.instruments), 1, inst_count) |
|
|
|
for bar in bars_time: |
|
plt.axvline(bar, color="grey", linewidth=0.5) |
|
octaves = np.arange(0, 128, 12) |
|
for octave in octaves: |
|
plt.axhline(octave, color="grey", linewidth=0.5) |
|
plt.yticks(octaves, visible=False) |
|
|
|
p_midi_note_list = inst.notes |
|
note_time = [] |
|
note_pitch = [] |
|
for note in p_midi_note_list: |
|
note_time.append([note.start, note.end]) |
|
note_pitch.append([note.pitch, note.pitch]) |
|
note_pitch = np.array(note_pitch) |
|
note_time = np.array(note_time) |
|
|
|
plt.plot( |
|
note_time.T, |
|
note_pitch.T, |
|
color=color, |
|
linewidth=4, |
|
solid_capstyle="butt", |
|
) |
|
plt.ylim(0, 128) |
|
xticks = np.array(bars_time)[:-1] |
|
plt.tight_layout() |
|
plt.xlim(min(bars_time), max(bars_time)) |
|
plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5) |
|
plt.xticks( |
|
xticks + 0.5 * beats_per_bar * sec_per_beat, |
|
labels=xticks.argsort() + 1, |
|
visible=False, |
|
) |
|
plt.text( |
|
0.2, |
|
note_pitch.max() + 4, |
|
inst.name, |
|
fontsize=20, |
|
color=color, |
|
horizontalalignment="left", |
|
verticalalignment="top", |
|
) |
|
|
|
return piano_roll_fig |
|
|