File size: 5,010 Bytes
7edf1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from constants import INSTRUMENT_CLASSES

# matplotlib configuration
matplotlib.use("Agg")  # for server
matplotlib.rcParams.update(
    {
        # hide the major tick marks on both axes
        "xtick.major.size": 0,
        "ytick.major.size": 0,
        # grey plotting area without a visible frame
        "axes.facecolor": "grey",
        "axes.edgecolor": "none",
    }
)


def define_generation_dir(model_repo_path):
    """Return the directory for generated MIDI files, creating it if needed.

    Args:
        model_repo_path: Model repository identifier or path; it becomes the
            sub-directory under ``midi/generated``.

    Returns:
        The relative path ``midi/generated/<model_repo_path>`` as a string.

    Raises:
        FileExistsError: If the target path already exists but is not a
            directory.
    """
    #### to remove later ####
    if model_repo_path == "models/model_2048_fake_wholedataset":
        model_repo_path = "misnaej/the-jam-machine"
    #### to remove later ####
    generated_sequence_files_path = f"midi/generated/{model_repo_path}"
    # exist_ok=True replaces the former exists()/makedirs() pair, which could
    # fail if another process created the directory between the two calls.
    os.makedirs(generated_sequence_files_path, exist_ok=True)
    return generated_sequence_files_path


def bar_count_check(sequence, n_bars):
    """Check whether ``sequence`` holds exactly ``n_bars`` "BAR_END" tokens.

    The sequence is split on single spaces and only tokens equal to
    "BAR_END" are counted. "BAR_START" is deliberately ignored because it is
    not always part of ``sequence`` (e.g. it belongs to the prompt when one
    more bar is generated).

    Returns:
        A tuple ``(matches, count)``: whether the count equals ``n_bars``,
        and the number of "BAR_END" tokens found.
    """
    tokens = sequence.split(" ")
    found = tokens.count("BAR_END")
    is_match = found == n_bars
    if is_match:
        return is_match, found
    print(f"Bar count is {found} - but should be {n_bars}")
    return is_match, found


def print_inst_classes(INSTRUMENT_CLASSES):
    """Write each entry of ``INSTRUMENT_CLASSES`` to stdout, one per line."""
    for entry in INSTRUMENT_CLASSES:
        print(format(entry))


def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
    """Verify that every prompted instrument has an ``INST=<name>`` vocab token.

    On the first instrument without such a token, the known instrument
    classes are printed and a ``ValueError`` is raised that lists the sorted
    instrument names found in ``tokenizer.vocab``.

    Raises:
        ValueError: If ``INST=<inst>`` is missing from ``tokenizer.vocab``.
    """
    for inst in inst_prompt_list:
        if f"INST={inst}" in tokenizer.vocab:
            continue

        # Gather the part after the last "=" of every token mentioning "INST"
        known_names = []
        for tok in tokenizer.vocab:
            if "INST" in tok:
                known_names.append(tok.split("=")[-1])
        instruments_in_dataset = np.sort(known_names)

        print_inst_classes(INSTRUMENT_CLASSES)
        raise ValueError(
            f"The instrument {inst} is not in the tokenizer vocabulary. \n"
            f"                Available Instruments: {instruments_in_dataset}"
        )


def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
    """Force the generated sequence to have the expected number of bars.

    ``bar_count`` and ``expected_length`` both refer to the newly generated
    text only (without the input prompt).

    Args:
        input_prompt: Text that was fed to the model; prepended to the result.
        generated: Newly generated text, with bars terminated by "BAR_END ".
        bar_count: Number of bars counted in ``generated``.
        expected_length: Number of bars ``generated`` should contain.

    Returns:
        A tuple ``(full_piece, bar_count_checks)``:

        * too long: the prompt plus the first ``expected_length`` bars,
          closed with "TRACK_END ", and ``True``;
        * too short: the prompt plus ``generated`` unchanged, and ``False``
          (the caller is expected to regenerate);
        * exact: the prompt plus ``generated`` unchanged, and ``True``.
    """
    if bar_count > expected_length:  # Cut the sequence if too long
        # max(..., 0) keeps a non-positive expected_length from turning into
        # a "drop from the end" slice.
        kept_bars = generated.split("BAR_END ")[: max(expected_length, 0)]
        truncated = "".join(bar + "BAR_END " for bar in kept_bars) + "TRACK_END "
        full_piece = input_prompt + truncated
        print(f"Generated sequence trunkated at {expected_length} bars")
        return full_piece, True

    if bar_count < expected_length:  # Do nothing if the sequence is too short
        full_piece = input_prompt + generated
        print("--- Generated sequence is too short - Force Regeration ---")
        return full_piece, False

    # Exact match: previously neither branch ran and the return statement
    # raised UnboundLocalError.
    return input_prompt + generated, True


def get_max_time(inst_midi):
    """Return the largest ``get_end_time()`` over all instruments, at least 0."""
    end_times = [inst.get_end_time() for inst in inst_midi.instruments]
    # Seeding with 0 keeps the result at 0 when there are no instruments
    return max([0, *end_times])


def plot_piano_roll(inst_midi):
    """Draw one piano-roll subplot per instrument and return the figure.

    Args:
        inst_midi: Object exposing ``instruments`` (each with ``notes`` and
            ``name``) and ``get_beats()``. Presumably a pretty_midi-style
            object whose times are in seconds - TODO confirm against caller.

    Returns:
        The matplotlib figure holding all the subplots.
    """
    # Figure is 25 wide and 3 high per instrument
    piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
    piano_roll_fig.tight_layout()
    piano_roll_fig.patch.set_alpha(0.1)
    inst_count = 0
    beats_per_bar = 4
    # NOTE(review): hard-coded; 0.5 s per beat would correspond to 120 bpm -
    # confirm this matches the tempo of the plotted pieces.
    sec_per_beat = 0.5
    # Extrapolate one beat past the last one, using the spacing between the
    # first two beats.
    next_beat = max(inst_midi.get_beats()) + np.diff(inst_midi.get_beats())[0]
    # Keep every `beats_per_bar`-th beat time as a bar boundary; astype(int)
    # truncates any fractional part.
    bars_time = np.append(inst_midi.get_beats(), (next_beat))[::beats_per_bar].astype(
        int
    )
    for inst in inst_midi.instruments:
        inst_count += 1
        # One row per instrument, stacked vertically (subplot index is 1-based)
        plt.subplot(len(inst_midi.instruments), 1, inst_count)

        # Grid: a vertical line per bar boundary, a horizontal line every
        # 12 pitches
        for bar in bars_time:
            plt.axvline(bar, color="grey", linewidth=0.5)
        octaves = np.arange(0, 128, 12)
        for octave in octaves:
            plt.axhline(octave, color="grey", linewidth=0.5)
        plt.yticks(octaves, visible=False)

        # Collect [start, end] and [pitch, pitch] pairs, one per note
        p_midi_note_list = inst.notes
        note_time = []
        note_pitch = []
        for note in p_midi_note_list:
            note_time.append([note.start, note.end])
            note_pitch.append([note.pitch, note.pitch])

        # After transposing, each column is one note, so every note is drawn
        # as a horizontal segment from its start to its end at its pitch.
        plt.plot(
            np.array(note_time).T,
            np.array(note_pitch).T,
            color="purple",
            linewidth=3,
            solid_capstyle="butt",
        )
        plt.ylim(0, 128)
        # All bar boundaries except the last one
        xticks = np.array(bars_time)[:-1]
        plt.tight_layout()
        plt.xlim(min(bars_time), max(bars_time))
        # plt.xlabel("bars")
        # Ticks are shifted by half of `beats_per_bar * sec_per_beat` past
        # each boundary; labels are argsort() + 1, which presumably yields
        # 1..n bar numbers when the boundaries are increasing - verify.
        # visible=False hides the tick labels.
        plt.xticks(
            xticks + 0.5 * beats_per_bar * sec_per_beat,
            labels=xticks.argsort() + 1,
            visible=False,
        )
        plt.title(inst.name, fontsize=10, color="white")

    return piano_roll_fig