import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from utils import writeToFile, get_datetime

from constants import INSTRUMENT_CLASSES
from playback import get_music, show_piano_roll

# matplotlib settings
matplotlib.use("Agg")  # for server
matplotlib.rcParams["xtick.major.size"] = 0
matplotlib.rcParams["ytick.major.size"] = 0
matplotlib.rcParams["axes.facecolor"] = "none"
matplotlib.rcParams["axes.edgecolor"] = "grey"


class WriteTextMidiToFile:  # utility for saving the midi text from the GenerateMidiText class to file
    def __init__(self, generate_midi, output_path):
        self.generated_midi = generate_midi.generated_piece
        self.output_path = output_path
        self.hyperparameter_and_bars = generate_midi.piece_by_track

    def hashing_seq(self):
        self.current_time = get_datetime()
        self.output_path_filename = f"{self.output_path}/{self.current_time}.json"

    def wrapping_seq_hyperparameters_in_dict(self):
        # assert type(self.generated_midi) is str, "error: generated_midi must be a string"
        # assert (
        #     type(self.hyperparameter_and_bars) is dict
        # ), "error: hyperparameter_and_bars must be a dictionary"
        return {
            "generated_midi": self.generated_midi,
            "hyperparameters_and_bars": self.hyperparameter_and_bars,
        }

    def text_midi_to_file(self):
        self.hashing_seq()
        output_dict = self.wrapping_seq_hyperparameters_in_dict()
        print(f"Generated MIDI text written to: {self.output_path_filename}")
        writeToFile(self.output_path_filename, output_dict)
        return self.output_path_filename
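

# Illustrative usage sketch (comments only, not executed on import). `generator`
# below stands for a GenerateMidiText instance, which is an assumption based on
# the class comment above; any object exposing `generated_piece` and
# `piece_by_track` would work.
#
#   writer = WriteTextMidiToFile(generator, define_generation_dir("output/generated"))
#   json_path = writer.text_midi_to_file()  # writes <timestamp>.json and returns its path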


def define_generation_dir(generation_dir):
    if not os.path.exists(generation_dir):
        os.makedirs(generation_dir)
    return generation_dir


def bar_count_check(sequence, n_bars):
    """check if the sequence contains the right number of bars"""
    sequence = sequence.split(" ")
    # find occurrences of "BAR_END" in "sequence"
    # "BAR_START" is not checked because it is not always included in "sequence",
    # e.g. BAR_START is part of the prompt when generating one more bar
    bar_count = 0
    for seq in sequence:
        if seq == "BAR_END":
            bar_count += 1
    bar_count_matches = bar_count == n_bars
    if not bar_count_matches:
        print(f"Bar count is {bar_count} - but should be {n_bars}")
    return bar_count_matches, bar_count
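

# Example (hypothetical token strings, for illustration only): only "BAR_END"
# tokens are counted, so a two-bar sequence yields (True, 2).
#
#   bar_count_check("BAR_START NOTE_ON=60 BAR_END BAR_START NOTE_ON=62 BAR_END", 2)
#   -> (True, 2)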


def print_inst_classes(INSTRUMENT_CLASSES):
    """Print the instrument classes"""
    for classe in INSTRUMENT_CLASSES:
        print(f"{classe}")


def check_if_prompt_inst_in_tokenizer_vocab(tokenizer, inst_prompt_list):
    """Check if the prompt instruments are in the tokenizer vocab"""
    for inst in inst_prompt_list:
        if f"INST={inst}" not in tokenizer.vocab:
            instruments_in_dataset = np.sort(
                [tok.split("=")[-1] for tok in tokenizer.vocab if "INST" in tok]
            )
            print_inst_classes(INSTRUMENT_CLASSES)
            raise ValueError(
                f"""The instrument {inst} is not in the tokenizer vocabulary. 
                Available Instruments: {instruments_in_dataset}"""
            )
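

# Illustrative call (sketch): each prompt instrument must appear as an
# "INST=<name>" token in the tokenizer vocabulary; "DRUMS" and "4" below are
# placeholder values, not guaranteed members of the actual vocab.
#
#   check_if_prompt_inst_in_tokenizer_vocab(tokenizer, ["DRUMS", "4"])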


# TODO
def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list):
    pass


def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
    """Force the generated sequence to have the expected length.
    expected_length and bar_count refer to the newly generated part only (without the input prompt).
    """

    if bar_count - expected_length > 0:  # Cut the sequence if too long
        full_piece = ""
        split_bars = generated.split("BAR_END ")
        for count, bar in enumerate(split_bars):
            if count < expected_length:
                full_piece += bar + "BAR_END "

        full_piece += "TRACK_END "
        full_piece = input_prompt + full_piece
        print(f"Generated sequence truncated at {expected_length} bars")
        bar_count_checks = True

    elif bar_count - expected_length < 0:  # Sequence too short: keep it and flag for regeneration
        full_piece = input_prompt + generated
        bar_count_checks = False
        print("--- Generated sequence is too short - Force Regeneration ---")

    else:  # Bar count already matches the expected length
        full_piece = input_prompt + generated
        bar_count_checks = True

    return full_piece, bar_count_checks
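

# Typical flow (sketch, names are illustrative only): run bar_count_check on the
# newly generated tokens and, if the count is off, repair with forcing_bar_count.
#
#   ok, count = bar_count_check(newly_generated, n_bars=8)
#   if not ok:
#       full_piece, ok = forcing_bar_count(input_prompt, newly_generated, count, 8)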


def get_max_time(inst_midi):
    """Return the latest note end time (in seconds) across all instruments."""
    max_time = 0
    for inst in inst_midi.instruments:
        max_time = max(max_time, inst.get_end_time())
    return max_time


def plot_piano_roll(inst_midi):
    """Plot one piano-roll subplot per instrument (assumes 4/4 bars at 120 bpm, i.e. 0.5 s per beat)."""
    piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
    piano_roll_fig.tight_layout()
    piano_roll_fig.patch.set_alpha(0)
    inst_count = 0
    beats_per_bar = 4
    sec_per_beat = 0.5
    next_beat = max(inst_midi.get_beats()) + np.diff(inst_midi.get_beats())[0]
    bars_time = np.append(inst_midi.get_beats(), (next_beat))[::beats_per_bar].astype(
        int
    )
    for inst in inst_midi.instruments:
        # hardcoded for now
        if inst.name == "Drums":
            color = "purple"
        elif inst.name == "Synth Bass 1":
            color = "orange"
        else:
            color = "green"

        inst_count += 1
        plt.subplot(len(inst_midi.instruments), 1, inst_count)

        for bar in bars_time:
            plt.axvline(bar, color="grey", linewidth=0.5)
        octaves = np.arange(0, 128, 12)
        for octave in octaves:
            plt.axhline(octave, color="grey", linewidth=0.5)
        plt.yticks(octaves, visible=False)

        p_midi_note_list = inst.notes
        note_time = []
        note_pitch = []
        for note in p_midi_note_list:
            note_time.append([note.start, note.end])
            note_pitch.append([note.pitch, note.pitch])
        note_pitch = np.array(note_pitch)
        note_time = np.array(note_time)

        plt.plot(
            note_time.T,
            note_pitch.T,
            color=color,
            linewidth=4,
            solid_capstyle="butt",
        )
        plt.ylim(0, 128)
        xticks = np.array(bars_time)[:-1]
        plt.tight_layout()
        plt.xlim(min(bars_time), max(bars_time))
        plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5)
        plt.xticks(
            xticks + 0.5 * beats_per_bar * sec_per_beat,
            labels=xticks.argsort() + 1,
            visible=False,
        )
        plt.text(
            0.2,
            note_pitch.max() + 4,
            inst.name,
            fontsize=20,
            color=color,
            horizontalalignment="left",
            verticalalignment="top",
        )

    return piano_roll_fig
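

if __name__ == "__main__":
    # Minimal smoke test (sketch). It assumes the MIDI objects handled above are
    # pretty_midi.PrettyMIDI instances, which the use of .instruments, .get_beats()
    # and .get_end_time() suggests; pretty_midi is not imported by this module, so
    # this block is illustrative only and the values below are arbitrary.
    import pretty_midi

    test_midi = pretty_midi.PrettyMIDI()  # default 120 bpm -> 0.5 s per beat
    drums = pretty_midi.Instrument(program=0, is_drum=True, name="Drums")
    bass = pretty_midi.Instrument(program=38, name="Synth Bass 1")
    for beat in range(8):  # two 4/4 bars of quarter notes
        start = beat * 0.5
        drums.notes.append(pretty_midi.Note(100, 36, start, start + 0.25))
        bass.notes.append(pretty_midi.Note(100, 40 + beat, start, start + 0.5))
    test_midi.instruments.extend([drums, bass])

    print(f"max time: {get_max_time(test_midi):.2f} s")
    fig = plot_piano_roll(test_midi)
    out_dir = define_generation_dir("generated_files")
    fig.savefig(f"{out_dir}/piano_roll_smoke_test.png")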