MikeMpapa committed on
Commit
c9d7e0a
1 Parent(s): 89b0ea4

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +87 -20
utils.py CHANGED
@@ -11,14 +11,19 @@ from constants import GM_INSTRUMENTS, SAMPLE_RATE
11
  from string_to_notes import token_sequence_to_note_sequence
12
  from model import get_model_and_tokenizer
13
 
 
14
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
  # Load the tokenizer and the model
18
  model, tokenizer = get_model_and_tokenizer()
19
 
 
 
 
20
 
21
- def create_seed_string(genre: str = "OTHER", artist: str = "OTHER") -> str:
 
22
  """
23
  Creates a seed string for generating a new piece.
24
 
@@ -29,13 +34,13 @@ def create_seed_string(genre: str = "OTHER", artist: str = "OTHER") -> str:
29
  str: The seed string.
30
  """
31
  if genre == "RANDOM" and artist == "RANDOM":
32
- seed_string = "PIECE_START"
33
  elif genre == "RANDOM" and artist != "RANDOM":
34
- seed_string = f"PIECE_START GENRE=RANDOM ARTIST={artist} TRACK_START"
35
  elif genre != "RANDOM" and artist == "RANDOM":
36
- seed_string = f"PIECE_START GENRE={genre} ARTIST=RANDOM TRACK_START"
37
  else:
38
- seed_string = f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START"
39
  return seed_string
40
 
41
 
@@ -61,6 +66,39 @@ def get_instruments(text_sequence: str) -> List[str]:
61
  return instruments
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
65
  """
66
  Generates a new instrument sequence from a given seed and temperature.
@@ -165,6 +203,13 @@ def remove_last_instrument(
165
  return audio, midi_file, fig, instruments_str, new_song, num_tokens
166
 
167
 
 
 
 
 
 
 
 
168
  def regenerate_last_instrument(
169
  text_sequence: str, qpm: int = 120
170
  ) -> Tuple[ndarray, str, Figure, str, str, str]:
@@ -179,19 +224,32 @@ def regenerate_last_instrument(
179
  Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
180
  instruments string, new song string, and number of tokens string.
181
  """
182
- last_inst_index = text_sequence.rfind("INST=")
183
- if last_inst_index == -1:
184
- # No instrument so start from empty sequence
185
- audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
186
- text_sequence="", qpm=qpm
187
- )
 
 
 
 
 
 
 
 
 
 
 
 
188
  else:
189
- # Take it from the last instrument and continue generation
190
- next_space_index = text_sequence.find(" ", last_inst_index)
191
- new_seed = text_sequence[:next_space_index]
192
- audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
193
- text_sequence=new_seed, qpm=qpm
194
- )
 
195
  return audio, midi_file, fig, instruments_str, new_song, num_tokens
196
 
197
 
@@ -218,9 +276,10 @@ def change_tempo(
218
  def generate_song(
219
  genre: str = "OTHER",
220
  artist: str = "KATE_BUSH",
 
221
  temp: float = 0.75,
222
  text_sequence: str = "",
223
- qpm: int = 120,
224
  ) -> Tuple[ndarray, str, Figure, str, str, str]:
225
  """
226
  Generates a song given a genre, temperature, initial text sequence, and tempo.
@@ -238,13 +297,21 @@ def generate_song(
238
  Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
239
  instruments string, generated song string, and number of tokens string.
240
  """
 
 
 
 
 
 
 
241
  if text_sequence == "":
242
- seed_string = create_seed_string(genre, artist)
243
  else:
244
- seed_string = text_sequence
245
 
246
  generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
247
  audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
248
  generated_sequence, qpm
249
  )
250
  return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
 
 
11
  from string_to_notes import token_sequence_to_note_sequence
12
  from model import get_model_and_tokenizer
13
 
14
+ import json
15
 
16
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the model
model, tokenizer = get_model_and_tokenizer()

# Instruments: loaded once at import time from a path relative to the current
# working directory. The functions in this module look instrument names up by
# position in this collection (position 0 is treated as drums).
# The encoding is given explicitly so decoding does not depend on the locale.
with open("instruments.json", "r", encoding="utf-8") as f:
    instruments = json.load(f)
24
 
25
+
26
def create_seed_string(genre: str = "OTHER", artist: str = "OTHER", instrument: str = "0") -> str:
    """
    Creates a seed string for generating a new piece.

    The "RANDOM" placeholders need no special handling: for every combination
    of genre and artist the resulting token layout is the same, so the values
    are interpolated directly.

    Args:
        genre (str, optional): Value written after ``GENRE=``. Defaults to "OTHER".
        artist (str, optional): Value written after ``ARTIST=``. Defaults to "OTHER".
        instrument (str, optional): Value written after ``INST=``. Defaults to "0".

    Returns:
        str: The seed string.
    """
    return f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START INST={instrument}"
45
 
46
 
 
66
  return instruments
67
 
68
 
69
def change_last_instrument(
    text_sequence: str,
    instrument: str,
    temp: float = 0.75,
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Swaps the instrument of the last ``INST=`` token and re-renders the sequence.

    No new tokens are generated: only the last token containing ``INST=`` is
    rewritten, then the outputs are recomputed from the edited sequence.

    Args:
        text_sequence (str): The token sequence to edit.
        instrument (str): Instrument name; it is looked up in the module-level
            ``instruments`` collection.
        temp (float, optional): Not used by this function. Defaults to 0.75.
        qpm (int, optional): Forwarded to ``get_outputs_from_string``. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, edited song string, and number of tokens string.

    Raises:
        ValueError: If ``instrument`` is not found in ``instruments``.
    """
    position = instruments.index(instrument)
    # Position 0 maps to the DRUMS token; every other position is shifted down by one.
    inst_token = "DRUMS" if position == 0 else str(position - 1)

    # The sequence is re-joined with single spaces, so whitespace runs are normalised.
    tokens = text_sequence.split()
    # Walk backwards so that only the last instrument token is replaced.
    for token_idx in range(len(tokens) - 1, -1, -1):
        if "INST=" in tokens[token_idx]:
            tokens[token_idx] = f"INST={inst_token}"
            break
    new_sequence = " ".join(tokens)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, new_sequence, num_tokens
99
+
100
+
101
+
102
  def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
103
  """
104
  Generates a new instrument sequence from a given seed and temperature.
 
203
  return audio, midi_file, fig, instruments_str, new_song, num_tokens
204
 
205
 
206
# NOTE(review): these lines sit between two function definitions and look like
# a parameter list pasted here by accident (they mirror generate_song's
# parameters) — confirm and remove if unintended.
+ genre: str = "OTHER",
207
+ artist: str = "KATE_BUSH",
208
+ instrument: str = "Acoustic Grand Piano",
209
+ temp: float = 0.75,
210
+ text_sequence: str = "",
211
+ qpm: int = 120
212
+
213
def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last track of a sequence, keeping its instrument.

    The instrument is read from the last token containing ``INST=``. Everything
    from the last ``TRACK_START`` onwards is dropped and ``generate_song`` is
    asked to continue from the remaining prefix with that instrument. When the
    sequence contains no ``INST=`` token, a song is generated from an empty
    sequence instead.

    Args:
        text_sequence (str): The token sequence of the current song.
        qpm (int, optional): Tempo forwarded to ``generate_song``. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, new song string, and number of tokens string.

    Raises:
        ValueError: If the last instrument id is neither "DRUMS" nor an integer.
        IndexError: If the instrument id falls outside ``instruments``.
    """
    # Find the id carried by the last instrument token, if there is one.
    instrument_id = None
    for token in reversed(text_sequence.split()):
        if "INST=" in token:
            instrument_id = token.split("=")[1]
            break

    if instrument_id is None:
        # No instrument so start from empty sequence
        return generate_song(text_sequence="", qpm=qpm)

    if instrument_id == "DRUMS":
        instrument = "Drums"
    else:
        # Token ids are offset by one from positions in `instruments`
        # (position 0 is the one handled as drums).
        instrument = instruments[int(instrument_id) + 1]

    # Keep everything before the last TRACK_START; the prefix is empty when the
    # marker is absent. generate_song adds the new track header itself.
    new_seed, _, _ = text_sequence.rpartition("TRACK_START")

    audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
        instrument=instrument, text_sequence=new_seed, qpm=qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
254
 
255
 
 
276
def generate_song(
    genre: str = "OTHER",
    artist: str = "KATE_BUSH",
    instrument: str = "Acoustic Grand Piano",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates a song, or a further track for an existing song.

    With an empty ``text_sequence`` a fresh seed is built from the genre,
    artist and instrument. Otherwise a new track header for the instrument is
    appended to the given sequence and generation continues from there.

    Args:
        genre (str, optional): Genre used when building a fresh seed. Defaults to "OTHER".
        artist (str, optional): Artist used when building a fresh seed. Defaults to "KATE_BUSH".
        instrument (str, optional): Instrument name, looked up in the module-level
            ``instruments`` collection. Defaults to "Acoustic Grand Piano".
        temp (float, optional): Temperature passed to ``generate_new_instrument``. Defaults to 0.75.
        text_sequence (str, optional): Existing sequence to extend. Defaults to "".
        qpm (int, optional): Tempo passed to ``get_outputs_from_string``. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, generated song string, and number of tokens string.
    """
    position = instruments.index(instrument)
    # Drums live at position 0; all other instruments use their position minus one.
    inst_token = "DRUMS" if position == 0 else str(position - 1)

    if text_sequence == "":
        seed_string = create_seed_string(genre, artist, inst_token)
    else:
        seed_string = text_sequence + " TRACK_START INST=" + inst_token

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    outputs = get_outputs_from_string(generated_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = outputs
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
317
+