from generation_utils import *
import random


class GenerateMidiText:
    """Generate music as MIDI-text with a causal language model.

    Call flow:

    Generating a track from scratch:
        - ``self.generate_one_new_track()`` calls
        - ``self.generate_until_track_end()``

    Generating new bars for existing tracks:
        - ``self.generate_one_more_bar()`` calls
        - ``self.process_prompt_for_next_bar()``
        - ``self.generate_until_track_end()``

    The piece is kept in ``self.piece_by_track``: a list with one dict per
    track holding "label", "instrument", "density", "temperature" and
    "bars". ``bars[0]`` is the track header (the chunk containing
    "TRACK_START", e.g. "TRACK_START INST=DRUMS DENSITY=2 "); the following
    entries are "BAR_START ..." strings.
    """

    def __init__(self, model, tokenizer, piece_by_track=None):
        """Store the model/tokenizer and initialise default settings.

        Args:
            model: generation model; must expose ``config.n_positions`` and
                ``generate(...)``.
            tokenizer: must expose ``encode``, ``decode`` and ``vocab``.
            piece_by_track (list | None): existing per-track dicts to continue
                from. Defaults to a new empty list for this instance.
        """
        self.model = model
        self.tokenizer = tokenizer
        # default initialization
        self.initialize_default_parameters()
        # A fresh list per instance: a mutable default argument would be shared
        # by every instance created without an explicit piece_by_track.
        self.initialize_dictionaries([] if piece_by_track is None else piece_by_track)

    # ------------------------------------------------------------------
    # Setters
    # ------------------------------------------------------------------

    def initialize_default_parameters(self):
        """Reset device, attention length, stop token, length forcing,
        bars-per-model and improvisation level to their defaults."""
        self.set_device()
        self.set_attention_length()
        # generation stops at the first token id of this string
        self.generate_until = "TRACK_END"
        self.set_force_sequence_lenth()
        self.set_nb_bars_generated()
        self.set_improvisation_level(0)

    def initialize_dictionaries(self, piece_by_track):
        """Use ``piece_by_track`` (stored as is, not copied) as the piece state."""
        self.piece_by_track = piece_by_track

    def set_device(self, device="cpu"):
        """Record the device name (not read by any method of this class)."""
        self.device = device

    def set_attention_length(self):
        """Use the model's ``config.n_positions`` as the generation max_length."""
        self.max_length = self.model.config.n_positions
        print(
            f"Attention length set to {self.max_length} -> 'model.config.n_positions'"
        )

    def set_force_sequence_lenth(self, force_sequence_length=True):
        """Enable/disable enforcing the expected bar count after generation."""
        self.force_sequence_length = force_sequence_length

    def set_improvisation_level(self, improvisation_value):
        """Set the ``no_repeat_ngram_size`` passed to ``model.generate``."""
        self.no_repeat_ngram_size = improvisation_value
        print("--------------------")
        print(f"no_repeat_ngram_size set to {improvisation_value}")
        print("--------------------")

    def reset_temperatures(self, track_id, temperature):
        """Change the sampling temperature stored for one track."""
        self.piece_by_track[track_id]["temperature"] = temperature

    def set_nb_bars_generated(self, n_bars=8):  # default is a 8 bar model
        """Set the number of bars the model generates per sequence."""
        self.model_n_bar = n_bars

    # ------------------------------------------------------------------
    # Generation Tools - Dictionaries
    # ------------------------------------------------------------------

    def initiate_track_dict(self, instr, density, temperature):
        """Append an empty track entry labelled ``track_<index>``."""
        label = len(self.piece_by_track)
        self.piece_by_track.append(
            {
                "label": f"track_{label}",
                "instrument": instr,
                "density": density,
                "temperature": temperature,
                "bars": [],
            }
        )

    def update_track_dict__add_bars(self, bars, track_id):
        """Add bars to the track dictionary.

        ``bars`` (after removing a trailing TRACK_END) is split on
        "BAR_START ". The chunk containing "TRACK_START" (the track header) is
        stored as is. Every other non-empty chunk gets its "BAR_START " prefix
        restored.
        """
        for bar in self.striping_track_ends(bars).split("BAR_START "):
            if bar == "":
                # happens when the text starts with "BAR_START " (e.g. a single bar)
                continue
            if "TRACK_START" in bar:
                self.piece_by_track[track_id]["bars"].append(bar)
            else:
                self.piece_by_track[track_id]["bars"].append("BAR_START " + bar)

    def get_all_instr_bars(self, track_id):
        """Return the list of bar strings (header first) of a track."""
        return self.piece_by_track[track_id]["bars"]

    def striping_track_ends(self, text):
        """Remove one trailing "TRACK_END" token.

        If ``text`` contains "TRACK_END", trailing spaces are stripped. Then,
        if the text ends with "TRACK_END", that exact suffix is removed.
        ``rstrip("TRACK_END")`` must not be used here: it strips any trailing
        run of the characters T,R,A,C,K,_,E,N,D (e.g. "BAR_END" -> "B").
        """
        if "TRACK_END" in text:
            text = text.rstrip(" ")
            if text.endswith("TRACK_END"):
                text = text[: -len("TRACK_END")]
        return text

    def get_last_generated_track(self, piece):
        """Get the last track from a piece written as a single long string"""
        return self.get_tracks_from_a_piece(piece)[-1]

    def get_tracks_from_a_piece(self, piece):
        """Get all the tracks from a piece written as a single long string.

        Each returned track has the form "TRACK_START ... TRACK_END ".
        """
        # Strip each track's own TRACK_END before appending exactly one; the
        # whole list must not be passed to striping_track_ends (a list never
        # "contains" the substring, which left "TRACK_END TRACK_END ").
        return [
            "TRACK_START " + self.striping_track_ends(the_track) + "TRACK_END "
            for the_track in piece.split("TRACK_START ")[1:]
        ]

    def get_piece_from_track_list(self, track_list):
        """Join track strings into a piece starting with "PIECE_START "."""
        return "PIECE_START " + "".join(track_list)

    def get_whole_track_from_bar_dict(self, track_id):
        """Rebuild a track string from its stored bars, closed by "TRACK_END "."""
        return "".join(self.piece_by_track[track_id]["bars"]) + "TRACK_END "

    @staticmethod
    def get_newly_generated_text(input_prompt, full_piece):
        """Return the part of ``full_piece`` that follows ``input_prompt``."""
        return full_piece[len(input_prompt) :]

    def get_whole_piece_from_bar_dict(self):
        """Rebuild the whole piece string from all stored tracks."""
        return "PIECE_START " + "".join(
            self.get_whole_track_from_bar_dict(track_id)
            for track_id in range(len(self.piece_by_track))
        )

    def delete_one_track(self, track):
        """Remove the track at index ``track``."""
        self.piece_by_track.pop(track)

    # ------------------------------------------------------------------
    # Basic generation tools
    # ------------------------------------------------------------------

    def tokenize_input_prompt(self, input_prompt, verbose=True):
        """Tokenizing prompt

        Args:
        - input_prompt (str): prompt to tokenize

        Returns:
        - input_prompt_ids (torch.tensor): tokenized prompt
        """
        if verbose:
            print("Tokenizing input_prompt...")
        return self.tokenizer.encode(input_prompt, return_tensors="pt")

    def generate_sequence_of_token_ids(
        self,
        input_prompt_ids,
        temperature,
        verbose=True,
    ):
        """Sample a token id sequence continuing ``input_prompt_ids``.

        Generation is bounded by ``max_length=self.max_length`` (the model's
        n_positions). It stops early at ``eos_token_id``, the first token id
        of ``self.generate_until``.
        """
        if verbose:
            print("Generating a token_id sequence...")
        return self.model.generate(
            input_prompt_ids,
            max_length=self.max_length,
            do_sample=True,
            temperature=temperature,
            no_repeat_ngram_size=self.no_repeat_ngram_size,  # default = 0
            eos_token_id=self.tokenizer.encode(self.generate_until)[0],  # good
        )

    def convert_ids_to_text(self, generated_ids, verbose=True):
        """Decode the first sequence of ``generated_ids`` to text."""
        generated_text = self.tokenizer.decode(generated_ids[0])
        if verbose:
            print("Converting token sequence to MidiText...")
        return generated_text

    def generate_until_track_end(
        self,
        input_prompt="PIECE_START ",
        instrument=None,
        density=None,
        temperature=None,
        verbose=True,
        expected_length=None,
    ):
        """Generate until the TRACK_END token is reached.

        Args:
            input_prompt (str): text the generation continues from.
            instrument: if given, "TRACK_START INST=<instrument> " is appended
                to the prompt.
            density: if given, "DENSITY=<density> " is appended to the prompt.
            temperature: sampling temperature (required).
            verbose (bool): print progress.
            expected_length (int | None): expected number of generated bars;
                defaults to ``self.model_n_bar``.

        Returns:
            str: full_piece = input_prompt + generated (possibly adjusted by
            ``forcing_bar_count`` when the bar count is enforced).

        Raises:
            ValueError: if ``temperature`` is None.
        """
        if expected_length is None:
            expected_length = self.model_n_bar

        if instrument is not None:
            input_prompt = f"{input_prompt}TRACK_START INST={instrument} "
        if density is not None:
            input_prompt = f"{input_prompt}DENSITY={density} "
        if instrument is None and density is not None:
            print("Density cannot be defined without an input_prompt instrument #TOFIX")

        if temperature is None:
            raise ValueError("Temperature must be defined")

        if verbose:
            print("--------------------")
            print(
                f"Generating {instrument} - Density {density} - temperature {temperature}"
            )

        bar_count_checks = False
        failed = 0
        while not bar_count_checks:  # regenerate until right length
            input_prompt_ids = self.tokenize_input_prompt(input_prompt, verbose=verbose)
            generated_tokens = self.generate_sequence_of_token_ids(
                input_prompt_ids, temperature, verbose=verbose
            )
            full_piece = self.convert_ids_to_text(generated_tokens, verbose=verbose)
            generated = self.get_newly_generated_text(input_prompt, full_piece)

            # does the new text contain expected_length bars?
            bar_count_checks, bar_count = bar_count_check(generated, expected_length)

            if not self.force_sequence_length:
                # length not enforced: accept the sequence and exit the loop
                bar_count_checks = True

            if not bar_count_checks and self.force_sequence_length:
                # the generated sequence is not the expected length
                if failed > -1:  # always true: the "regenerate" branch is deactivated for speed
                    full_piece, bar_count_checks = forcing_bar_count(
                        input_prompt,
                        generated,
                        bar_count,
                        expected_length,
                    )
                else:
                    print("--- Wrong length - Regenerating ---")

            if not bar_count_checks:
                failed += 1
                if failed > 2:
                    bar_count_checks = True  # exit the while loop if failed too much

        return full_piece

    def generate_one_new_track(
        self,
        instrument,
        density,
        temperature,
        input_prompt="PIECE_START ",
    ):
        """Generate a new track conditioned on ``input_prompt`` and store it.

        Returns:
            str: the whole piece rebuilt from the stored tracks.
        """
        self.initiate_track_dict(instrument, density, temperature)
        full_piece = self.generate_until_track_end(
            input_prompt=input_prompt,
            instrument=instrument,
            density=density,
            temperature=temperature,
        )

        track = self.get_last_generated_track(full_piece)
        self.update_track_dict__add_bars(track, -1)
        return self.get_whole_piece_from_bar_dict()

    # ------------------------------------------------------------------
    # Piece generation - Basics
    # ------------------------------------------------------------------

    def generate_piece(self, instrument_list, density_list, temperature_list):
        """Generate a sequence with multiple tracks.

        Args:
            - instrument_list sets the list of instruments and the order of
              generation.
            - density_list and temperature_list are paired with
              instrument_list (zip: generation stops at the shortest list).

        Each track/instrument is generated from a prompt that contains the
        previously generated tracks/instruments.

        Returns:
            'generated_piece', which keeps track of the entire piece.
        """
        generated_piece = "PIECE_START "
        for instrument, density, temperature in zip(
            instrument_list, density_list, temperature_list
        ):
            generated_piece = self.generate_one_new_track(
                instrument,
                density,
                temperature,
                input_prompt=generated_piece,
            )

        self.check_the_piece_for_errors()
        return generated_piece

    # ------------------------------------------------------------------
    # Piece generation - Extra Bars
    # ------------------------------------------------------------------

    def process_prompt_for_next_bar(self, track_idx, verbose=True):
        """Build the prompt for the model to generate one more bar only.

        The prompt contains:
          - "PIECE_START ", then every *other* track that already has more
            bars than this one. Each such track is reduced to its header plus
            the bars aligned with this track's context window, including the
            bar concurrent with the one being generated.
          - this track's header (ex: "TRACK_START INST=DRUMS DENSITY=2 ").
          - this track's last (self.model_n_bar - 1) bars.
          - an open "BAR_START " for the model to complete.

        The side-track part is passed through ``force_prompt_length`` (limit
        1500 space-separated items).

        Args:
            track_idx (int): the index of the track to be processed.
            verbose (bool): print which bars are selected.

        Returns:
            the processed prompt for generating the next bar.
        """
        track = self.piece_by_track[track_idx]
        n_context = max(self.model_n_bar - 1, 0)
        # bars[0] is the header; the musical bars start at index 1
        main_body = track["bars"][1:]
        # context window of this track, in main_body indexing; max() avoids
        # re-including the header for short tracks and [-0:] taking everything
        window_start = max(len(main_body) - n_context, 0)
        # the bar to generate will sit at index len(main_body); side tracks
        # contribute the window plus that concurrent bar
        window_end = len(main_body) + 1

        # side tracks: only those ahead of this one, which should catch up
        pre_promt = "PIECE_START "
        for i, othertrack in enumerate(self.piece_by_track):
            if i == track_idx or len(othertrack["bars"]) <= len(track["bars"]):
                continue
            side_bars = othertrack["bars"][1:][window_start:window_end]
            if verbose:
                print(
                    f"Adding bars - {len(side_bars)} selected from SIDE track: {i} for prompt"
                )
            pre_promt += othertrack["bars"][0] + "".join(side_bars) + "TRACK_END "

        # the track to prolong: header, then its last bars, then an open bar
        main_bars = main_body[window_start:]
        if verbose:
            print(
                f"Adding bars - {len(main_bars)} selected from MAIN track: {track_idx} for prompt"
            )
        processed_prompt = track["bars"][0] + "".join(main_bars) + "BAR_START "

        # shorten the pre-prompt to avoid bugs due to the prompt length (model limitation)
        pre_promt = self.force_prompt_length(pre_promt, 1500)
        print(
            f"--- prompt length = {len((pre_promt + processed_prompt).split(' '))} ---"
        )
        return pre_promt + processed_prompt

    def force_prompt_length(self, prompt, expected_length):
        """Remove one randomly chosen track from the prompt if it is too long.

        Only one track is removed per call. The remaining tracks keep their
        original order.

        Args:
            prompt (str): the prompt to be processed.
            expected_length (int): the prompt is kept as is while it has fewer
                than this many space-separated items.

        Returns:
            the (possibly) truncated prompt.
        """
        if len(prompt.split(" ")) < expected_length:
            return prompt
        tracks = self.get_tracks_from_a_piece(prompt)
        if not tracks:
            # nothing that can be removed
            return prompt
        dropped = random.randrange(len(tracks))
        print("Prompt too long - deleting one track")
        return self.get_piece_from_track_list(tracks[:dropped] + tracks[dropped + 1 :])

    def generate_one_more_bar(self, track_index):
        """Generate one more bar for a track and append it to its bars."""
        processed_prompt = self.process_prompt_for_next_bar(track_index)
        prompt_plus_bar = self.generate_until_track_end(
            input_prompt=processed_prompt,
            temperature=self.piece_by_track[track_index]["temperature"],
            expected_length=1,
            verbose=False,
        )
        added_bar = self.get_newly_generated_bar(prompt_plus_bar)
        self.update_track_dict__add_bars(added_bar, track_index)

    def get_newly_generated_bar(self, prompt_plus_bar):
        """Return the last "BAR_START ..." chunk, without a trailing TRACK_END."""
        return "BAR_START " + self.striping_track_ends(
            prompt_plus_bar.split("BAR_START ")[-1]
        )

    def generate_n_more_bars(self, n_bars, only_this_track=None, verbose=True):
        """Generate n more bars, one bar at a time across tracks.

        Each round adds one bar to every track, or only to track
        ``only_this_track`` when it is given. The piece is then checked for
        unknown tokens.
        """
        if verbose:
            print("================== ")
            print(f"Adding {n_bars} more bars to the piece ")
        for bar_id in range(n_bars):
            if verbose:
                print(f"----- added bar #{bar_id+1} --")
            for i, track in enumerate(self.piece_by_track):
                if only_this_track is None or i == only_this_track:
                    if verbose:
                        print(f"--------- {track['label']}")
                    self.generate_one_more_bar(i)
        self.check_the_piece_for_errors()

    def check_the_piece_for_errors(self, piece: str = None):
        """Report tokens of ``piece`` that are unknown to the tokenizer.

        Args:
            piece (str | None): piece to check; defaults to the piece rebuilt
                from ``self.piece_by_track``.

        Returns:
            list[tuple[str, int]]: (token, position) for every token that is
            not in the tokenizer vocabulary or equals "UNK". Empty strings
            produced by splitting on spaces (e.g. the trailing space) are
            ignored.
        """
        if piece is None:
            piece = self.get_whole_piece_from_bar_dict()
        tokens = piece.split(" ")
        vocab = self.tokenizer.vocab  # fetched once rather than once per token
        errors = [
            (token, idx)
            for idx, token in enumerate(tokens)
            if token != "" and (token not in vocab or token == "UNK")
        ]
        for token, idx in errors:
            print(f"Token not found in the piece at {idx}: {token}")
            print(tokens[max(idx - 5, 0) : idx + 5])
        return errors


if __name__ == "__main__":
    pass