m41w4r3.exe commited on
Commit
6cc2135
1 Parent(s): facf84e

fix genesis caching

Browse files
Files changed (4) hide show
  1. decoder.py +1 -1
  2. generate.py +4 -4
  3. generation_utils.py +29 -9
  4. playground.py +56 -38
decoder.py CHANGED
@@ -178,7 +178,7 @@ class TextDecoder:
178
  inst = 0
179
  is_drum = 1
180
  if self.familized:
181
- inst = Familizer(arbitrary=True).get_program_number(int(inst)) + 1
182
  instruments.append((int(inst), is_drum))
183
  return tuple(instruments)
184
 
 
178
  inst = 0
179
  is_drum = 1
180
  if self.familized:
181
+ inst = Familizer(arbitrary=True).get_program_number(int(inst))
182
  instruments.append((int(inst), is_drum))
183
  return tuple(instruments)
184
 
generate.py CHANGED
@@ -21,12 +21,12 @@ class GenerateMidiText:
21
  - self.process_prompt_for_next_bar()
22
  - self.generate_until_track_end()"""
23
 
24
- def __init__(self, model, tokenizer):
25
  self.model = model
26
  self.tokenizer = tokenizer
27
  # default initialization
28
  self.initialize_default_parameters()
29
- self.initialize_dictionaries()
30
 
31
  """Setters"""
32
 
@@ -38,8 +38,8 @@ class GenerateMidiText:
38
  self.set_nb_bars_generated()
39
  self.set_improvisation_level(0)
40
 
41
- def initialize_dictionaries(self):
42
- self.piece_by_track = []
43
 
44
  def set_device(self, device="cpu"):
45
  self.device = ("cpu",)
 
21
  - self.process_prompt_for_next_bar()
22
  - self.generate_until_track_end()"""
23
 
24
+ def __init__(self, model, tokenizer, piece_by_track=[]):
25
  self.model = model
26
  self.tokenizer = tokenizer
27
  # default initialization
28
  self.initialize_default_parameters()
29
+ self.initialize_dictionaries(piece_by_track)
30
 
31
  """Setters"""
32
 
 
38
  self.set_nb_bars_generated()
39
  self.set_improvisation_level(0)
40
 
41
+ def initialize_dictionaries(self, piece_by_track):
42
+ self.piece_by_track = piece_by_track
43
 
44
  def set_device(self, device="cpu"):
45
  self.device = ("cpu",)
generation_utils.py CHANGED
@@ -2,14 +2,16 @@ import os
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib
 
5
  from constants import INSTRUMENT_CLASSES
 
6
 
7
  # matplotlib settings
8
  matplotlib.use("Agg") # for server
9
  matplotlib.rcParams["xtick.major.size"] = 0
10
  matplotlib.rcParams["ytick.major.size"] = 0
11
- matplotlib.rcParams["axes.facecolor"] = "grey"
12
- matplotlib.rcParams["axes.edgecolor"] = "none"
13
 
14
 
15
  def define_generation_dir(model_repo_path):
@@ -93,7 +95,7 @@ def get_max_time(inst_midi):
93
  def plot_piano_roll(inst_midi):
94
  piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
95
  piano_roll_fig.tight_layout()
96
- piano_roll_fig.patch.set_alpha(0.1)
97
  inst_count = 0
98
  beats_per_bar = 4
99
  sec_per_beat = 0.5
@@ -102,6 +104,14 @@ def plot_piano_roll(inst_midi):
102
  int
103
  )
104
  for inst in inst_midi.instruments:
 
 
 
 
 
 
 
 
105
  inst_count += 1
106
  plt.subplot(len(inst_midi.instruments), 1, inst_count)
107
 
@@ -118,24 +128,34 @@ def plot_piano_roll(inst_midi):
118
  for note in p_midi_note_list:
119
  note_time.append([note.start, note.end])
120
  note_pitch.append([note.pitch, note.pitch])
 
 
121
 
122
  plt.plot(
123
- np.array(note_time).T,
124
- np.array(note_pitch).T,
125
- color="purple",
126
- linewidth=3,
127
  solid_capstyle="butt",
128
  )
129
  plt.ylim(0, 128)
130
  xticks = np.array(bars_time)[:-1]
131
  plt.tight_layout()
132
  plt.xlim(min(bars_time), max(bars_time))
133
- # plt.xlabel("bars")
134
  plt.xticks(
135
  xticks + 0.5 * beats_per_bar * sec_per_beat,
136
  labels=xticks.argsort() + 1,
137
  visible=False,
138
  )
139
- plt.title(inst.name, fontsize=10, color="white")
 
 
 
 
 
 
 
 
140
 
141
  return piano_roll_fig
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib
5
+
6
  from constants import INSTRUMENT_CLASSES
7
+ from playback import get_music, show_piano_roll
8
 
9
  # matplotlib settings
10
  matplotlib.use("Agg") # for server
11
  matplotlib.rcParams["xtick.major.size"] = 0
12
  matplotlib.rcParams["ytick.major.size"] = 0
13
+ matplotlib.rcParams["axes.facecolor"] = "none"
14
+ matplotlib.rcParams["axes.edgecolor"] = "grey"
15
 
16
 
17
  def define_generation_dir(model_repo_path):
 
95
  def plot_piano_roll(inst_midi):
96
  piano_roll_fig = plt.figure(figsize=(25, 3 * len(inst_midi.instruments)))
97
  piano_roll_fig.tight_layout()
98
+ piano_roll_fig.patch.set_alpha(0)
99
  inst_count = 0
100
  beats_per_bar = 4
101
  sec_per_beat = 0.5
 
104
  int
105
  )
106
  for inst in inst_midi.instruments:
107
+ # hardcoded for now
108
+ if inst.name == "Drums":
109
+ color = "purple"
110
+ elif inst.name == "Synth Bass 1":
111
+ color = "orange"
112
+ else:
113
+ color = "green"
114
+
115
  inst_count += 1
116
  plt.subplot(len(inst_midi.instruments), 1, inst_count)
117
 
 
128
  for note in p_midi_note_list:
129
  note_time.append([note.start, note.end])
130
  note_pitch.append([note.pitch, note.pitch])
131
+ note_pitch = np.array(note_pitch)
132
+ note_time = np.array(note_time)
133
 
134
  plt.plot(
135
+ note_time.T,
136
+ note_pitch.T,
137
+ color=color,
138
+ linewidth=4,
139
  solid_capstyle="butt",
140
  )
141
  plt.ylim(0, 128)
142
  xticks = np.array(bars_time)[:-1]
143
  plt.tight_layout()
144
  plt.xlim(min(bars_time), max(bars_time))
145
+ plt.ylim(max([note_pitch.min() - 5, 0]), note_pitch.max() + 5)
146
  plt.xticks(
147
  xticks + 0.5 * beats_per_bar * sec_per_beat,
148
  labels=xticks.argsort() + 1,
149
  visible=False,
150
  )
151
+ plt.text(
152
+ 0.2,
153
+ note_pitch.max() + 4,
154
+ inst.name,
155
+ fontsize=20,
156
+ color=color,
157
+ horizontalalignment="left",
158
+ verticalalignment="top",
159
+ )
160
 
161
  return piano_roll_fig
playground.py CHANGED
@@ -26,7 +26,6 @@ model, tokenizer = LoadModel(
26
  model_repo, from_huggingface=True, revision=revision
27
  ).load_model_and_tokenizer()
28
 
29
-
30
  miditok = get_miditok()
31
  decoder = TextDecoder(miditok)
32
 
@@ -40,32 +39,49 @@ def define_prompt(state, genesis):
40
 
41
 
42
  def generator(
43
- regenerate, temp, density, instrument, state, add_bars=False, add_bar_count=1
 
 
 
 
 
 
 
 
44
  ):
45
 
 
 
46
  inst = next(
47
  (inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
48
  {"family_number": "DRUMS"},
49
  )["family_number"]
50
 
51
- inst_index = index_has_substring(state, "INST=" + str(inst))
52
-
53
- # Regenerate
54
- if regenerate:
55
- state.pop(inst_index)
56
- genesis.delete_one_track(inst_index)
57
- generated_text = (
58
- genesis.get_whole_piece_from_bar_dict()
59
- ) # maybe not useful here
60
- inst_index = -1 # reset to last generated
61
 
62
  # Generate
63
  if not add_bars:
 
 
 
 
 
 
 
 
 
 
64
  # NEW TRACK
65
  input_prompt = define_prompt(state, genesis)
66
  generated_text = genesis.generate_one_new_track(
67
  inst, density, temp, input_prompt=input_prompt
68
  )
 
 
69
  else:
70
  # NEW BARS
71
  genesis.generate_n_more_bars(add_bar_count) # for all instruments
@@ -79,14 +95,23 @@ def generator(
79
  decoder.get_midi(inst_text, inst_midi_name)
80
  _, inst_audio = get_music(inst_midi_name)
81
  piano_roll = plot_piano_roll(mixed_inst_midi)
82
- state.append(inst_text)
83
-
84
- return inst_text, (44100, inst_audio), piano_roll, state, (44100, mixed_audio)
85
-
 
 
 
 
 
 
 
 
86
 
87
- def instrument_row(default_inst):
88
 
 
89
  with gr.Row():
 
90
  with gr.Column(scale=1, min_width=50):
91
  inst = gr.Dropdown(
92
  [inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
@@ -100,35 +125,33 @@ def instrument_row(default_inst):
100
  output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
101
  with gr.Column(scale=1, min_width=100):
102
  inst_audio = gr.Audio(label="Audio")
103
- regenerate = gr.Checkbox(value=False, label="Regenerate")
104
  # add_bars = gr.Checkbox(value=False, label="Add Bars")
105
  # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
106
  gen_btn = gr.Button("Generate")
107
  gen_btn.click(
108
  fn=generator,
109
- inputs=[
110
- regenerate,
111
- temp,
112
- density,
113
- inst,
114
  state,
 
 
 
115
  ],
116
- outputs=[output_txt, inst_audio, piano_roll, state, mixed_audio],
117
  )
118
 
119
 
120
- with gr.Blocks(cache_examples=False) as demo:
121
- genesis = GenerateMidiText(
122
- model,
123
- tokenizer,
124
- )
125
- genesis.set_nb_bars_generated(n_bars=n_bar_generated)
126
  state = gr.State([])
127
  mixed_audio = gr.Audio(label="Mixed Audio")
128
  piano_roll = gr.Plot(label="Piano Roll")
129
- instrument_row("Drums")
130
- instrument_row("Bass")
131
- instrument_row("Synth Lead")
132
  # instrument_row("Piano")
133
 
134
  demo.launch(debug=True)
@@ -138,14 +161,9 @@ TODO: DEPLOY
138
  TODO: temp file situation
139
  TODO: clear cache situation
140
  TODO: reset button
141
- TODO: instrument mapping business
142
- TODO: Y lim axis of piano roll
143
  TODO: add a button to save the generated midi
144
  TODO: add improvise button
145
- TODO: making the piano roll fit on the horizontal scale
146
  TODO: set values for temperature as it is done for density
147
- TODO: set the color situation to be dark background
148
- TODO: make regeration default when an intrument has already been track has already been generated
149
  TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
150
  TODO: row height to fix
151
 
 
26
  model_repo, from_huggingface=True, revision=revision
27
  ).load_model_and_tokenizer()
28
 
 
29
  miditok = get_miditok()
30
  decoder = TextDecoder(miditok)
31
 
 
39
 
40
 
41
  def generator(
42
+ label,
43
+ regenerate,
44
+ temp,
45
+ density,
46
+ instrument,
47
+ state,
48
+ piece_by_track,
49
+ add_bars=False,
50
+ add_bar_count=1,
51
  ):
52
 
53
+ genesis = GenerateMidiText(model, tokenizer, piece_by_track)
54
+ track = {"label": label}
55
  inst = next(
56
  (inst for inst in INSTRUMENT_CLASSES if inst["name"] == instrument),
57
  {"family_number": "DRUMS"},
58
  )["family_number"]
59
 
60
+ inst_index = -1 # default to last generated
61
+ if state != []:
62
+ for index, instrum in enumerate(state):
63
+ if instrum["label"] == track["label"]:
64
+ inst_index = index # changing if exists
 
 
 
 
 
65
 
66
  # Generate
67
  if not add_bars:
68
+ # Regenerate
69
+ if regenerate:
70
+ state.pop(inst_index)
71
+ genesis.delete_one_track(inst_index)
72
+
73
+ generated_text = (
74
+ genesis.get_whole_piece_from_bar_dict()
75
+ ) # maybe not useful here
76
+ inst_index = -1 # reset to last generated
77
+
78
  # NEW TRACK
79
  input_prompt = define_prompt(state, genesis)
80
  generated_text = genesis.generate_one_new_track(
81
  inst, density, temp, input_prompt=input_prompt
82
  )
83
+
84
+ regenerate = True # set generate to true
85
  else:
86
  # NEW BARS
87
  genesis.generate_n_more_bars(add_bar_count) # for all instruments
 
95
  decoder.get_midi(inst_text, inst_midi_name)
96
  _, inst_audio = get_music(inst_midi_name)
97
  piano_roll = plot_piano_roll(mixed_inst_midi)
98
+ track["text"] = inst_text
99
+ state.append(track)
100
+
101
+ return (
102
+ inst_text,
103
+ (44100, inst_audio),
104
+ piano_roll,
105
+ state,
106
+ (44100, mixed_audio),
107
+ regenerate,
108
+ genesis.piece_by_track,
109
+ )
110
 
 
111
 
112
+ def instrument_row(default_inst, row_id):
113
  with gr.Row():
114
+ row = gr.Variable(row_id)
115
  with gr.Column(scale=1, min_width=50):
116
  inst = gr.Dropdown(
117
  [inst["name"] for inst in INSTRUMENT_CLASSES] + ["Drums"],
 
125
  output_txt = gr.Textbox(label="output", lines=10, max_lines=10)
126
  with gr.Column(scale=1, min_width=100):
127
  inst_audio = gr.Audio(label="Audio")
128
+ regenerate = gr.Checkbox(value=False, label="Regenerate", visible=False)
129
  # add_bars = gr.Checkbox(value=False, label="Add Bars")
130
  # add_bar_count = gr.Dropdown([1, 2, 4, 8], value=1, label="Add Bars")
131
  gen_btn = gr.Button("Generate")
132
  gen_btn.click(
133
  fn=generator,
134
+ inputs=[row, regenerate, temp, density, inst, state, piece_by_track],
135
+ outputs=[
136
+ output_txt,
137
+ inst_audio,
138
+ piano_roll,
139
  state,
140
+ mixed_audio,
141
+ regenerate,
142
+ piece_by_track,
143
  ],
 
144
  )
145
 
146
 
147
+ with gr.Blocks() as demo:
148
+ piece_by_track = gr.State([])
 
 
 
 
149
  state = gr.State([])
150
  mixed_audio = gr.Audio(label="Mixed Audio")
151
  piano_roll = gr.Plot(label="Piano Roll")
152
+ instrument_row("Drums", 0)
153
+ instrument_row("Bass", 1)
154
+ instrument_row("Synth Lead", 2)
155
  # instrument_row("Piano")
156
 
157
  demo.launch(debug=True)
 
161
  TODO: temp file situation
162
  TODO: clear cache situation
163
  TODO: reset button
 
 
164
  TODO: add a button to save the generated midi
165
  TODO: add improvise button
 
166
  TODO: set values for temperature as it is done for density
 
 
167
  TODO: Add bar should be now set for the whole piece - regenerrate should regenerate the added bars only on all instruments
168
  TODO: row height to fix
169