the-jam-machine-app / generate.py
matthew mitton
Duplicate from JammyMachina/the-jam-machine-app
9118de8
raw
history blame
18.3 kB
from generation_utils import *
from utils import WriteTextMidiToFile, get_miditok
from load import LoadModel
from decoder import TextDecoder
from playback import get_music
class GenerateMidiText:
"""Generating music with Class
LOGIC:
FOR GENERATING FROM SCRATCH:
- self.generate_one_new_track()
it calls
- self.generate_until_track_end()
FOR GENERATING NEW BARS:
- self.generate_one_more_bar()
it calls
- self.process_prompt_for_next_bar()
- self.generate_until_track_end()"""
def __init__(self, model, tokenizer, piece_by_track=[]):
self.model = model
self.tokenizer = tokenizer
# default initialization
self.initialize_default_parameters()
self.initialize_dictionaries(piece_by_track)
"""Setters"""
def initialize_default_parameters(self):
self.set_device()
self.set_attention_length()
self.generate_until = "TRACK_END"
self.set_force_sequence_lenth()
self.set_nb_bars_generated()
self.set_improvisation_level(0)
def initialize_dictionaries(self, piece_by_track):
self.piece_by_track = piece_by_track
def set_device(self, device="cpu"):
self.device = ("cpu",)
def set_attention_length(self):
self.max_length = self.model.config.n_positions
print(
f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
)
def set_force_sequence_lenth(self, force_sequence_length=True):
self.force_sequence_length = force_sequence_length
def set_improvisation_level(self, improvisation_value):
self.no_repeat_ngram_size = improvisation_value
print("--------------------")
print(f"no_repeat_ngram_size set to {improvisation_value}")
print("--------------------")
def reset_temperatures(self, track_id, temperature):
self.piece_by_track[track_id]["temperature"] = temperature
def set_nb_bars_generated(self, n_bars=8): # default is a 8 bar model
self.model_n_bar = n_bars
""" Generation Tools - Dictionnaries """
def initiate_track_dict(self, instr, density, temperature):
label = len(self.piece_by_track)
self.piece_by_track.append(
{
"label": f"track_{label}",
"instrument": instr,
"density": density,
"temperature": temperature,
"bars": [],
}
)
def update_track_dict__add_bars(self, bars, track_id):
"""Add bars to the track dictionnary"""
for bar in self.striping_track_ends(bars).split("BAR_START "):
if bar == "": # happens is there is one bar only
continue
else:
if "TRACK_START" in bar:
self.piece_by_track[track_id]["bars"].append(bar)
else:
self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)
def get_all_instr_bars(self, track_id):
return self.piece_by_track[track_id]["bars"]
def striping_track_ends(self, text):
if "TRACK_END" in text:
# first get rid of extra space if any
# then gets rid of "TRACK_END"
text = text.rstrip(" ").rstrip("TRACK_END")
return text
def get_last_generated_track(self, full_piece):
track = (
"TRACK_START "
+ self.striping_track_ends(full_piece.split("TRACK_START ")[-1])
+ "TRACK_END "
) # forcing the space after track and
return track
def get_selected_track_as_text(self, track_id):
text = ""
for bar in self.piece_by_track[track_id]["bars"]:
text += bar
text += "TRACK_END "
return text
@staticmethod
def get_newly_generated_text(input_prompt, full_piece):
return full_piece[len(input_prompt) :]
def get_whole_piece_from_bar_dict(self):
text = "PIECE_START "
for track_id, _ in enumerate(self.piece_by_track):
text += self.get_selected_track_as_text(track_id)
return text
def delete_one_track(self, track): # TO BE TESTED
self.piece_by_track.pop(track)
# def update_piece_dict__add_track(self, track_id, track):
# self.piece_dict[track_id] = track
# def update_all_dictionnaries__add_track(self, track):
# self.update_piece_dict__add_track(track_id, track)
"""Basic generation tools"""
def tokenize_input_prompt(self, input_prompt, verbose=True):
"""Tokenizing prompt
Args:
- input_prompt (str): prompt to tokenize
Returns:
- input_prompt_ids (torch.tensor): tokenized prompt
"""
if verbose:
print("Tokenizing input_prompt...")
return self.tokenizer.encode(input_prompt, return_tensors="pt")
def generate_sequence_of_token_ids(
self,
input_prompt_ids,
temperature,
verbose=True,
):
"""
generate a sequence of token ids based on input_prompt_ids
The sequence length depends on the trained model (self.model_n_bar)
"""
generated_ids = self.model.generate(
input_prompt_ids,
max_length=self.max_length,
do_sample=True,
temperature=temperature,
no_repeat_ngram_size=self.no_repeat_ngram_size, # default = 0
eos_token_id=self.tokenizer.encode(self.generate_until)[0], # good
)
if verbose:
print("Generating a token_id sequence...")
return generated_ids
def convert_ids_to_text(self, generated_ids, verbose=True):
"""converts the token_ids to text"""
generated_text = self.tokenizer.decode(generated_ids[0])
if verbose:
print("Converting token sequence to MidiText...")
return generated_text
def generate_until_track_end(
self,
input_prompt="PIECE_START ",
instrument=None,
density=None,
temperature=None,
verbose=True,
expected_length=None,
):
"""generate until the TRACK_END token is reached
full_piece = input_prompt + generated"""
if expected_length is None:
expected_length = self.model_n_bar
if instrument is not None:
input_prompt = f"{input_prompt}TRACK_START INST={str(instrument)} "
if density is not None:
input_prompt = f"{input_prompt}DENSITY={str(density)} "
if instrument is None and density is not None:
print("Density cannot be defined without an input_prompt instrument #TOFIX")
if temperature is None:
ValueError("Temperature must be defined")
if verbose:
print("--------------------")
print(
f"Generating {instrument} - Density {density} - temperature {temperature}"
)
bar_count_checks = False
failed = 0
while not bar_count_checks: # regenerate until right length
input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
generated_tokens = self.generate_sequence_of_token_ids(
input_prompt_ids, temperature, verbose=verbose
)
full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
generated = self.get_newly_generated_text(input_prompt, full_piece)
# bar_count_checks
bar_count_checks, bar_count = bar_count_check(generated, expected_length)
if not self.force_sequence_length:
# set bar_count_checks to true to exist the while loop
bar_count_checks = True
if not bar_count_checks and self.force_sequence_length:
# if the generated sequence is not the expected length
if failed > -1: # deactivated for speed
full_piece, bar_count_checks = forcing_bar_count(
input_prompt,
generated,
bar_count,
expected_length,
)
else:
print('"--- Wrong length - Regenerating ---')
if not bar_count_checks:
failed += 1
if failed > 2:
bar_count_checks = True # TOFIX exit the while loop
return full_piece
def generate_one_new_track(
self,
instrument,
density,
temperature,
input_prompt="PIECE_START ",
):
self.initiate_track_dict(instrument, density, temperature)
full_piece = self.generate_until_track_end(
input_prompt=input_prompt,
instrument=instrument,
density=density,
temperature=temperature,
)
track = self.get_last_generated_track(full_piece)
self.update_track_dict__add_bars(track, -1)
full_piece = self.get_whole_piece_from_bar_dict()
return full_piece
""" Piece generation - Basics """
def generate_piece(self, instrument_list, density_list, temperature_list):
"""generate a sequence with mutiple tracks
- inst_list sets the list of instruments of the order of generation
- density is paired with inst_list
Each track/intrument is generated on a prompt which contains the previously generated track/instrument
This means that the first instrument is generated with less bias than the next one, and so on.
'generated_piece' keeps track of the entire piece
'generated_piece' is returned by self.generate_until_track_end
# it is returned by self.generate_until_track_end"""
generated_piece = "PIECE_START "
for instrument, density, temperature in zip(
instrument_list, density_list, temperature_list
):
generated_piece = self.generate_one_new_track(
instrument,
density,
temperature,
input_prompt=generated_piece,
)
# generated_piece = self.get_whole_piece_from_bar_dict()
self.check_the_piece_for_errors()
return generated_piece
""" Piece generation - Extra Bars """
@staticmethod
def process_prompt_for_next_bar(self, track_idx):
"""Processing the prompt for the model to generate one more bar only.
The prompt containts:
if not the first bar: the previous, already processed, bars of the track
the bar initialization (ex: "TRACK_START INST=DRUMS DENSITY=2 ")
the last (self.model_n_bar)-1 bars of the track
Args:
track_idx (int): the index of the track to be processed
Returns:
the processed prompt for generating the next bar
"""
track = self.piece_by_track[track_idx]
# for bars which are not the bar to prolong
pre_promt = "PIECE_START "
for i, othertrack in enumerate(self.piece_by_track):
if i != track_idx:
len_diff = len(othertrack["bars"]) - len(track["bars"])
if len_diff > 0:
# if other bars are longer, it mean that this one should catch up
pre_promt += othertrack["bars"][0]
for bar in track["bars"][-self.model_n_bar :]:
pre_promt += bar
pre_promt += "TRACK_END "
elif False: # len_diff <= 0: # THIS GENERATES EMPTINESS
# adding an empty bars at the end of the other tracks if they have not been processed yet
pre_promt += othertracks["bars"][0]
for bar in track["bars"][-(self.model_n_bar - 1) :]:
pre_promt += bar
for _ in range(abs(len_diff) + 1):
pre_promt += "BAR_START BAR_END "
pre_promt += "TRACK_END "
# for the bar to prolong
# initialization e.g TRACK_START INST=DRUMS DENSITY=2
processed_prompt = track["bars"][0]
for bar in track["bars"][-(self.model_n_bar - 1) :]:
# adding the "last" bars of the track
processed_prompt += bar
processed_prompt += "BAR_START "
print(
f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
)
return pre_promt + processed_prompt
def generate_one_more_bar(self, i):
"""Generate one more bar from the input_prompt"""
processed_prompt = self.process_prompt_for_next_bar(self, i)
prompt_plus_bar = self.generate_until_track_end(
input_prompt=processed_prompt,
temperature=self.piece_by_track[i]["temperature"],
expected_length=1,
verbose=False,
)
added_bar = self.get_newly_generated_bar(prompt_plus_bar)
self.update_track_dict__add_bars(added_bar, i)
def get_newly_generated_bar(self, prompt_plus_bar):
return "BAR_START " + self.striping_track_ends(
prompt_plus_bar.split("BAR_START ")[-1]
)
def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
"""Generate n more bars from the input_prompt"""
if only_this_track is None:
only_this_track
print(f"================== ")
print(f"Adding {n_bars} more bars to the piece ")
for bar_id in range(n_bars):
print(f"----- added bar #{bar_id+1} --")
for i, track in enumerate(self.piece_by_track):
if only_this_track is None or i == only_this_track:
print(f"--------- {track['label']}")
self.generate_one_more_bar(i)
self.check_the_piece_for_errors()
def check_the_piece_for_errors(self, piece: str = None):
if piece is None:
piece = generate_midi.get_whole_piece_from_bar_dict()
errors = []
errors.append(
[
(token, id)
for id, token in enumerate(piece.split(" "))
if token not in self.tokenizer.vocab or token == "UNK"
]
)
if len(errors) > 0:
# print(piece)
for er in errors:
er
print(f"Token not found in the piece at {er[0][1]}: {er[0][0]}")
print(piece.split(" ")[er[0][1] - 5 : er[0][1] + 5])
if __name__ == "__main__":
# worker
DEVICE = "cpu"
# define generation parameters
N_FILES_TO_GENERATE = 2
Temperatures_to_try = [0.7]
USE_FAMILIZED_MODEL = True
force_sequence_length = True
if USE_FAMILIZED_MODEL:
# model_repo = "misnaej/the-jam-machine-elec-famil"
# model_repo = "misnaej/the-jam-machine-elec-famil-ft32"
# model_repo = "JammyMachina/elec-gmusic-familized-model-13-12__17-35-53"
# n_bar_generated = 8
model_repo = "JammyMachina/improved_4bars-mdl"
n_bar_generated = 4
instrument_promt_list = ["4", "DRUMS", "3"]
# DRUMS = drums, 0 = piano, 1 = chromatic percussion, 2 = organ, 3 = guitar, 4 = bass, 5 = strings, 6 = ensemble, 7 = brass, 8 = reed, 9 = pipe, 10 = synth lead, 11 = synth pad, 12 = synth effects, 13 = ethnic, 14 = percussive, 15 = sound effects
density_list = [3, 2, 2]
# temperature_list = [0.7, 0.7, 0.75]
else:
model_repo = "misnaej/the-jam-machine"
instrument_promt_list = ["30"] # , "DRUMS", "0"]
density_list = [3] # , 2, 3]
# temperature_list = [0.7, 0.5, 0.75]
pass
# define generation directory
generated_sequence_files_path = define_generation_dir(model_repo)
# load model and tokenizer
model, tokenizer = LoadModel(
model_repo, from_huggingface=True
).load_model_and_tokenizer()
# does the prompt make sense
check_if_prompt_inst_in_tokenizer_vocab(tokenizer, instrument_promt_list)
for temperature in Temperatures_to_try:
print(f"================= TEMPERATURE {temperature} =======================")
for _ in range(N_FILES_TO_GENERATE):
print(f"========================================")
# 1 - instantiate
generate_midi = GenerateMidiText(model, tokenizer)
# 0 - set the n_bar for this model
generate_midi.set_nb_bars_generated(n_bars=n_bar_generated)
# 1 - defines the instruments, densities and temperatures
# 2- generate the first 8 bars for each instrument
generate_midi.set_improvisation_level(30)
generate_midi.generate_piece(
instrument_promt_list,
density_list,
[temperature for _ in density_list],
)
# 3 - force the model to improvise
# generate_midi.set_improvisation_level(20)
# 4 - generate the next 4 bars for each instrument
# generate_midi.generate_n_more_bars(n_bar_generated)
# 5 - lower the improvisation level
generate_midi.generated_piece = (
generate_midi.get_whole_piece_from_bar_dict()
)
# print the generated sequence in terminal
print("=========================================")
print(generate_midi.generated_piece)
print("=========================================")
# write to JSON file
filename = WriteTextMidiToFile(
generate_midi,
generated_sequence_files_path,
).text_midi_to_file()
# decode the sequence to MIDI """
decode_tokenizer = get_miditok()
TextDecoder(decode_tokenizer, USE_FAMILIZED_MODEL).get_midi(
generate_midi.generated_piece, filename=filename.split(".")[0] + ".mid"
)
inst_midi, mixed_audio = get_music(filename.split(".")[0] + ".mid")
max_time = get_max_time(inst_midi)
plot_piano_roll(inst_midi)
print("Et voilà! Your MIDI file is ready! GO JAM!")