misnaej committed
Commit
725968f
1 Parent(s): 5a748f4

updated generation process - epsilon

Files changed (3):
  1. familizer.py +0 -1
  2. generate.py +62 -29
  3. generation_utils.py +35 -6
familizer.py CHANGED
@@ -115,7 +115,6 @@ class Familizer:
 
 
 if __name__ == "__main__":
-
     # Choose number of jobs for parallel processing
     n_jobs = -1
 
generate.py CHANGED
@@ -1,8 +1,5 @@
 from generation_utils import *
-from utils import WriteTextMidiToFile, get_miditok
-from load import LoadModel
-from decoder import TextDecoder
-from playback import get_music
+import random
 
 
 class GenerateMidiText:
@@ -100,15 +97,26 @@ class GenerateMidiText:
         text = text.rstrip(" ").rstrip("TRACK_END")
         return text
 
-    def get_last_generated_track(self, full_piece):
-        track = (
-            "TRACK_START "
-            + self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
-            + "TRACK_END "
-        )  # forcing the space after track and
+    def get_last_generated_track(self, piece):
+        """Get the last track from a piece written as a single long string"""
+        track = self.get_tracks_from_a_piece(piece)[-1]
         return track
 
-    def get_selected_track_as_text(self, track_id):
+    def get_tracks_from_a_piece(self, piece):
+        """Get all the tracks from a piece written as a single long string"""
+        all_tracks = [
+            "TRACK_START " + the_track + "TRACK_END "
+            for the_track in self.striping_track_ends(piece.split("TRACK_START ")[1::])
+        ]
+        return all_tracks
+
+    def get_piece_from_track_list(self, track_list):
+        piece = "PIECE_START "
+        for track in track_list:
+            piece += track
+        return piece
+
+    def get_whole_track_from_bar_dict(self, track_id):
         text = ""
         for bar in self.piece_by_track[track_id]["bars"]:
             text += bar
@@ -122,18 +130,12 @@ class GenerateMidiText:
     def get_whole_piece_from_bar_dict(self):
         text = "PIECE_START "
         for track_id, _ in enumerate(self.piece_by_track):
-            text += self.get_selected_track_as_text(track_id)
+            text += self.get_whole_track_from_bar_dict(track_id)
         return text
 
-    def delete_one_track(self, track):  # TO BE TESTED
+    def delete_one_track(self, track):
         self.piece_by_track.pop(track)
 
-    # def update_piece_dict__add_track(self, track_id, track):
-    #     self.piece_dict[track_id] = track
-
-    # def update_all_dictionnaries__add_track(self, track):
-    #     self.update_piece_dict__add_track(track_id, track)
-
     """Basic generation tools"""
 
     def tokenize_input_prompt(self, input_prompt, verbose=True):
@@ -238,10 +240,12 @@ class GenerateMidiText:
                     )
                 else:
                     print('"--- Wrong length - Regenerating ---')
+
            if not bar_count_checks:
                failed += 1
-                if failed > 2:
-                    bar_count_checks = True  # TOFIX exit the while loop
+
+            if failed > 2:
+                bar_count_checks = True  # exit the while loop if failed too much
 
        return full_piece
 
@@ -298,8 +302,7 @@ class GenerateMidiText:
 
     """ Piece generation - Extra Bars """
 
-    @staticmethod
-    def process_prompt_for_next_bar(self, track_idx):
+    def process_prompt_for_next_bar(self, track_idx, verbose=True):
         """Processing the prompt for the model to generate one more bar only.
         The prompt containts:
         if not the first bar: the previous, already processed, bars of the track
@@ -318,6 +321,10 @@ class GenerateMidiText:
             if i != track_idx:
                 len_diff = len(othertrack["bars"]) - len(track["bars"])
                 if len_diff > 0:
+                    if verbose:
+                        print(
+                            f"Adding bars - {len(track['bars'][-self.model_n_bar :])} selected from SIDE track: {i} for prompt"
+                        )
                     # if other bars are longer, it mean that this one should catch up
                     pre_promt += othertrack["bars"][0]
                     for bar in track["bars"][-self.model_n_bar :]:
@@ -325,7 +332,7 @@ class GenerateMidiText:
                     pre_promt += "TRACK_END "
                 elif (
                     False
-                ):  # len_diff <= 0: # THIS DOES NOT WORK - It just fills things with empty bars
+                ):  # len_diff <= 0: # THIS DOES NOT WORK - It just adds empty bars
                     # adding an empty bars at the end of the other tracks if they have not been processed yet
                     pre_promt += othertracks["bars"][0]
                     for bar in track["bars"][-(self.model_n_bar - 1) :]:
@@ -337,27 +344,54 @@ class GenerateMidiText:
         # for the bar to prolong
         # initialization e.g TRACK_START INST=DRUMS DENSITY=2
         processed_prompt = track["bars"][0]
+        if verbose:
+            print(
+                f"Adding bars - {len(track['bars'][-(self.model_n_bar - 1) :])} selected from MAIN track: {track_idx} for prompt"
+            )
         for bar in track["bars"][-(self.model_n_bar - 1) :]:
             # adding the "last" bars of the track
             processed_prompt += bar
 
         processed_prompt += "BAR_START "
+
+        # making the preprompt short enought to avoid bug due to length of the prompt (model limitation)
+        pre_promt = self.force_prompt_length(pre_promt, 1500)
+
         print(
             f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
         )
+
         return pre_promt + processed_prompt
 
-    def generate_one_more_bar(self, i):
+    def force_prompt_length(self, prompt, expected_length):
+        """remove one instrument/track from the prompt it too long
+        Args:
+            prompt (str): the prompt to be processed
+            expected_length (int): the expected length of the prompt
+        Returns:
+            the truncated prompt"""
+        if len(prompt.split(" ")) < expected_length:
+            truncated_prompt = prompt
+        else:
+            tracks = self.get_tracks_from_a_piece(prompt)
+            selected_tracks = random.sample(tracks, len(tracks) - 1)
+            truncated_prompt = self.get_piece_from_track_list(selected_tracks)
+            print(f"Prompt too long - deleting one track")
+
+        return truncated_prompt
+
+    def generate_one_more_bar(self, track_index):
         """Generate one more bar from the input_prompt"""
-        processed_prompt = self.process_prompt_for_next_bar(self, i)
+        processed_prompt = self.process_prompt_for_next_bar(track_index)
+
         prompt_plus_bar = self.generate_until_track_end(
             input_prompt=processed_prompt,
-            temperature=self.piece_by_track[i]["temperature"],
+            temperature=self.piece_by_track[track_index]["temperature"],
             expected_length=1,
             verbose=False,
         )
         added_bar = self.get_newly_generated_bar(prompt_plus_bar)
-        self.update_track_dict__add_bars(added_bar, i)
+        self.update_track_dict__add_bars(added_bar, track_index)
 
     def get_newly_generated_bar(self, prompt_plus_bar):
         return "BAR_START " + self.striping_track_ends(
@@ -380,7 +414,6 @@ class GenerateMidiText:
         self.check_the_piece_for_errors()
 
     def check_the_piece_for_errors(self, piece: str = None):
-
         if piece is None:
             piece = self.get_whole_piece_from_bar_dict()
         errors = []
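
Note on the change above: the new prompt-length guard splits the prompt into its tracks and drops one at random whenever the space-separated token count exceeds a budget (1500 in this commit), before asking the model for one more bar. A minimal standalone sketch of that idea, using made-up helper names and a toy piece string rather than the repo's GenerateMidiText class:

import random

def split_into_tracks(piece):
    # Break a "PIECE_START TRACK_START ... TRACK_END ..." string into one string per track
    chunks = piece.split("TRACK_START ")[1:]
    return ["TRACK_START " + chunk.split("TRACK_END")[0].rstrip() + " TRACK_END " for chunk in chunks]

def truncate_prompt(piece, max_tokens=1500):
    # Keep the prompt under the model's context budget by discarding one random track
    if len(piece.split(" ")) < max_tokens:
        return piece
    tracks = split_into_tracks(piece)
    kept = random.sample(tracks, len(tracks) - 1)
    return "PIECE_START " + "".join(kept)

if __name__ == "__main__":
    toy_piece = (
        "PIECE_START "
        "TRACK_START INST=DRUMS BAR_START NOTE_ON=36 BAR_END TRACK_END "
        "TRACK_START INST=BASS BAR_START NOTE_ON=40 BAR_END TRACK_END "
    )
    print(truncate_prompt(toy_piece, max_tokens=5))  # small budget to force the truncation path

Dropping a whole track (rather than trimming bars) keeps every remaining track internally consistent, at the cost of losing one instrument from the conditioning context.
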
generation_utils.py CHANGED
@@ -2,6 +2,7 @@ import os
 import numpy as np
 import matplotlib.pyplot as plt
 import matplotlib
+from utils import writeToFile, get_datetime
 
 from constants import INSTRUMENT_CLASSES
 from playback import get_music, show_piano_roll
@@ -14,11 +15,38 @@ matplotlib.rcParams["axes.facecolor"] = "none"
 matplotlib.rcParams["axes.edgecolor"] = "grey"
 
 
-def define_generation_dir(model_repo_path):
-    generated_sequence_files_path = f"midi/generated/{model_repo_path}"
-    if not os.path.exists(generated_sequence_files_path):
-        os.makedirs(generated_sequence_files_path)
-    return generated_sequence_files_path
+class WriteTextMidiToFile:  # utils saving miditext from teh class GenerateMidiText to file
+    def __init__(self, generate_midi, output_path):
+        self.generated_midi = generate_midi.generated_piece
+        self.output_path = output_path
+        self.hyperparameter_and_bars = generate_midi.piece_by_track
+
+    def hashing_seq(self):
+        self.current_time = get_datetime()
+        self.output_path_filename = f"{self.output_path}/{self.current_time}.json"
+
+    def wrapping_seq_hyperparameters_in_dict(self):
+        # assert type(self.generated_midi) is str, "error: generate_midi must be a string"
+        # assert (
+        #     type(self.hyperparameter_dict) is dict
+        # ), "error: feature_dict must be a dictionnary"
+        return {
+            "generated_midi": self.generated_midi,
+            "hyperparameters_and_bars": self.hyperparameter_and_bars,
+        }
+
+    def text_midi_to_file(self):
+        self.hashing_seq()
+        output_dict = self.wrapping_seq_hyperparameters_in_dict()
+        print(f"Token generate_midi written: {self.output_path_filename}")
+        writeToFile(self.output_path_filename, output_dict)
+        return self.output_path_filename
+
+
+def define_generation_dir(generation_dir):
+    if not os.path.exists(generation_dir):
+        os.makedirs(generation_dir)
+    return generation_dir
 
 
 def bar_count_check(sequence, n_bars):
@@ -64,7 +92,8 @@ def check_if_prompt_density_in_tokenizer_vocab(tokenizer, density_prompt_list):
 
 def forcing_bar_count(input_prompt, generated, bar_count, expected_length):
     """Forcing the generated sequence to have the expected length
-    expected_length and bar_count refers to the length of newly_generated_only (without input prompt)"""
+    expected_length and bar_count refers to the length of newly_generated_only (without input prompt)
+    """
 
     if bar_count - expected_length > 0:  # Cut the sequence if too long
         full_piece = ""
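
Note on the change above: the relocated WriteTextMidiToFile class bundles the generated token string with the per-track hyperparameters and bars, then writes them to a timestamped JSON file. As a rough standalone illustration of that save pattern using only the standard library (the repo's writeToFile and get_datetime helpers are assumed to behave roughly like this; the field values below are invented):

import json
import os
from datetime import datetime

def save_generated_piece(generated_piece, piece_by_track, output_dir):
    # One JSON file per generation, named by timestamp, holding the text and its settings
    os.makedirs(output_dir, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    path = f"{output_dir}/{stamp}.json"
    payload = {
        "generated_midi": generated_piece,
        "hyperparameters_and_bars": piece_by_track,
    }
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)
    print(f"Token sequence written: {path}")
    return path

if __name__ == "__main__":
    save_generated_piece(
        "PIECE_START TRACK_START INST=DRUMS BAR_START NOTE_ON=36 BAR_END TRACK_END ",
        [{"label": "DRUMS", "temperature": 0.75, "bars": ["BAR_START NOTE_ON=36 BAR_END "]}],
        "midi/generated/demo",
    )
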