import json
from typing import List, Tuple

import gradio as gr
import note_seq
import torch
from matplotlib.figure import Figure
from numpy import ndarray

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from model import get_model_and_tokenizer
from string_to_notes import token_sequence_to_note_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the model
model, tokenizer = get_model_and_tokenizer()

# Instruments: index 0 is "Drums"; the remaining entries follow the General
# MIDI program list, offset by one.
with open("instruments.json", "r") as f:
    instruments = json.load(f)


def create_seed_string(
    genre: str = "OTHER", artist: str = "OTHER", instrument: str = "0"
) -> str:
    """
    Creates a seed string for generating a new piece.

    Args:
        genre (str, optional): The genre of the piece, or "RANDOM". Defaults to "OTHER".
        artist (str, optional): The artist style for the piece, or "RANDOM". Defaults to "OTHER".
        instrument (str, optional): The instrument token for the first track
            (a GM program number as a string, or "DRUMS"). Defaults to "0".

    Returns:
        str: The seed string.
    """
    # "RANDOM" is passed through verbatim, so every genre/artist combination
    # collapses into a single template.
    return f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START INST={instrument}"


def get_instruments(text_sequence: str) -> List[str]:
    """
    Extracts the list of instruments from a text sequence.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The list of instruments.
    """
    instruments = []
    parts = text_sequence.split()
    for part in parts:
        if part.startswith("INST="):
            if part[5:] == "DRUMS":
                instruments.append("Drums")
            else:
                index = int(part[5:])
                instruments.append(GM_INSTRUMENTS[index])
    return instruments


def change_last_instrument(
    text_sequence: str, instrument: str, temp: float = 0.75, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Replaces the instrument of the last track in a song string and returns the
    various output formats.

    Args:
        text_sequence (str): The song string.
        instrument (str): The new instrument name, as listed in instruments.json.
        temp (float, optional): Unused by this function. Defaults to 0.75.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name,
            plot figure, instruments string, new song string, and number of tokens string.
    """
    instrument_idx = instruments.index(instrument)
    # instruments[0] is "Drums"; every other entry is offset by one from its
    # General MIDI program number.
    if instrument_idx == 0:
        instrument_token = "DRUMS"
    else:
        instrument_token = str(instrument_idx - 1)

    # Rewrite the INST= token of the last track.
    tokens = text_sequence.split()
    for token_idx in reversed(range(len(tokens))):
        if "INST=" in tokens[token_idx]:
            tokens[token_idx] = f"INST={instrument_token}"
            break
    text_sequence = " ".join(tokens)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
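
# A minimal illustrative sketch of the token format the helpers above produce
# and parse. The "POP" genre and the NOTE_ON token are hypothetical stand-ins
# for the checkpoint's real vocabulary, and the asserts assume
# GM_INSTRUMENTS[0] == "Acoustic Grand Piano".
def _token_format_example() -> None:
    seed = create_seed_string(genre="POP", artist="RANDOM", instrument="0")
    assert seed == "PIECE_START GENRE=POP ARTIST=RANDOM TRACK_START INST=0"
    # get_instruments maps INST= tokens back to human-readable names.
    song = seed + " NOTE_ON=60 TRACK_END TRACK_START INST=DRUMS"
    assert get_instruments(song) == ["Acoustic Grand Piano", "Drums"]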

def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which
            controls the randomness. Defaults to 0.75.

    Returns:
        str: The generated instrument sequence.
    """
    seed_length = len(tokenizer.encode(seed))

    while True:
        # Encode the conditioning tokens.
        input_ids = tokenizer.encode(seed, return_tensors="pt")

        # Move the input_ids tensor to the same device as the model.
        input_ids = input_ids.to(model.device)

        # Generate more tokens, stopping at the end of the track.
        eos_token_id = tokenizer.encode("TRACK_END")[0]
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        generated_sequence = tokenizer.decode(generated_ids[0])

        # Check if the generated sequence contains "NOTE_ON" beyond the seed;
        # otherwise the new track is empty and we sample again.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return generated_sequence


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Converts a generated sequence into various output formats including audio,
    MIDI, plot, etc.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The audio waveform, MIDI file name,
            plot figure, instruments string, and number of tokens string.
    """
    instrument_names = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {instrument}" for instrument in instrument_names)
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    synth = note_seq.fluidsynth
    array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    num_tokens = str(len(generated_sequence.split()))
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various
    output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name,
            plot figure, instruments string, new song string, and number of tokens string.
    """
    # Split the song into tracks by splitting on 'TRACK_START'.
    tracks = text_sequence.split("TRACK_START")
    # Keep all tracks except the last one.
    modified_tracks = tracks[:-1]
    # Join the tracks back together, adding back the 'TRACK_START' that was
    # removed by split.
    new_song = "TRACK_START".join(modified_tracks)

    if len(tracks) == 2:
        # Only the piece header remains, so generate a fresh single instrument from it.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song
        )
    elif len(tracks) == 1:
        # No instrument at all, so start from an empty sequence.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=""
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
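
# A minimal sketch of the TRACK_START split/join that remove_last_instrument
# (and remove_last_track below) rely on. The note tokens are hypothetical
# stand-ins; only the TRACK_START bookkeeping matters here.
def _remove_last_track_example() -> None:
    song = (
        "PIECE_START GENRE=OTHER ARTIST=OTHER "
        "TRACK_START INST=0 NOTE_ON=60 NOTE_OFF=60 TRACK_END "
        "TRACK_START INST=DRUMS NOTE_ON=36 TRACK_END"
    )
    tracks = song.split("TRACK_START")  # piece header + one chunk per track
    trimmed = "TRACK_START".join(tracks[:-1])  # drop the drum track
    assert get_instruments(trimmed) == ["Acoustic Grand Piano"]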
""" def remove_last_track(text_sequence): tracks = text_sequence.split("TRACK_START") # We keep all tracks except the last one useful_tracks = tracks[:-1] # We join the tracks back together, adding back the 'TRACK_START' that was removed by split text_sequence = "TRACK_START".join(useful_tracks) return text_sequence #last_inst_index = text_sequence.rfind("INST=") for token in reversed(text_sequence.split()): if 'INST=' in token: instrument_id = token.split('=')[1] break if instrument_id=="DRUMS": instrument="Drums" else: instrument=instruments[int(instrument_id)+1]# Index 0 instrument is 'Acoustic Grand Piano' for rendering:https://soundprogramming.net/file-formats/general-midi-instrument-list/#google_vignette new_seed = remove_last_track(text_sequence=text_sequence) audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song( instrument=instrument,text_sequence=new_seed, qpm=qpm ) return audio, midi_file, fig, instruments_str, new_song, num_tokens def change_tempo( text_sequence: str, qpm: int ) -> Tuple[ndarray, str, Figure, str, str, str]: """ Changes the tempo of a song string and returns the various output formats. Args: text_sequence (str): The song string. qpm (int): The new quarter notes per minute. Returns: Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure, instruments string, text sequence, and number of tokens string. """ audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string( text_sequence, qpm=qpm ) return audio, midi_file, fig, instruments_str, text_sequence, num_tokens def generate_song( genre: str = "OTHER", artist: str = "KATE_BUSH", instrument: str = "Acoustic Grand Piano", temp: float = 0.75, text_sequence: str = "", qpm: int = 120 ) -> Tuple[ndarray, str, Figure, str, str, str]: """ Generates a song given a genre, temperature, initial text sequence, and tempo. Args: model (AutoModelForCausalLM): The pretrained model used for generating the sequences. tokenizer (AutoTokenizer): The tokenizer used to encode and decode the sequences. genre (str, optional): The genre of the song. Defaults to "OTHER". artist (str, optional): The artist style to inspire the song. Defaults to "KATE_BUSH". temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75. text_sequence (str, optional): The initial text sequence for the song. Defaults to "". qpm (int, optional): The quarter notes per minute. Defaults to 120. Returns: Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure, instruments string, generated song string, and number of tokens string. """ instrument = instruments.index(instrument) #Drums if instrument == 0: instrument='DRUMS' else: instrument = str(instrument-1) if text_sequence == "": seed_string = create_seed_string(genre, artist, instrument) else: seed_string = text_sequence + " TRACK_START INST=" + instrument generated_sequence = generate_new_instrument(seed=seed_string, temp=temp) audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string( generated_sequence, qpm ) return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens