Spaces:
Build error
Build error
from typing import List, Optional, Tuple

import gradio as gr
import note_seq
import torch
from matplotlib.figure import Figure
from numpy import ndarray
from transformers import AutoTokenizer, AutoModelForCausalLM

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from model import get_model_and_tokenizer
from string_to_notes import token_sequence_to_note_sequence
# Prefer CUDA when it is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the functions in this file,
# which move tensors to `model.device` instead — confirm it is still needed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the model and its tokenizer once, at import time, so every request
# handled by the functions below shares the same instances.
model, tokenizer = get_model_and_tokenizer()
def create_seed_string(genre: str = "OTHER") -> str:
    """
    Build the prompt used to start generating a brand-new piece.

    Args:
        genre (str, optional): Genre tag to condition the piece on. The
            special value "RANDOM" produces a bare "PIECE_START" header with
            no genre and no opening TRACK_START. Defaults to "OTHER".

    Returns:
        str: The seed string.
    """
    # "RANDOM" leaves the genre choice (and the first track) to the model.
    return (
        "PIECE_START"
        if genre == "RANDOM"
        else f"PIECE_START GENRE={genre} TRACK_START"
    )
def get_instruments(text_sequence: str) -> List[str]:
    """
    List the instruments declared in a token string, in order of appearance.

    Each whitespace-separated "INST=<value>" token contributes one entry:
    "INST=DRUMS" maps to "Drums", any other value is parsed as an integer and
    looked up in GM_INSTRUMENTS.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The list of instrument names.
    """
    names = []
    for token in text_sequence.split():
        if not token.startswith("INST="):
            continue
        value = token[len("INST="):]
        # Drums have no General MIDI program number, so they get a fixed label.
        names.append("Drums" if value == "DRUMS" else GM_INSTRUMENTS[int(value)])
    return names
def generate_new_instrument(
    seed: str, temp: float = 0.75, max_attempts: Optional[int] = None
) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Sampling is repeated until the newly generated part (the tokens after the
    seed) contains at least one "NOTE_ON" event, so a track without notes is
    never returned.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The sampling temperature, which controls the
            randomness. Defaults to 0.75.
        max_attempts (int, optional): Upper bound on the number of sampling
            rounds. ``None`` (the default) keeps retrying until a track with
            notes is produced.

    Returns:
        str: The full decoded sequence, seed included.

    Raises:
        ValueError: If ``max_attempts`` is given and is less than 1.
        RuntimeError: If ``max_attempts`` is given and no attempt produced a
            "NOTE_ON" event.
    """
    if max_attempts is not None and max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1 or None, got {max_attempts}")

    # The seed and the end-of-track id are identical on every retry, so encode
    # them once instead of once per sampling round.
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    seed_length = input_ids.shape[-1]
    # Generation stops when the model emits the first token of "TRACK_END".
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        # Only inspect what was generated beyond the seed: the seed itself may
        # already contain notes from earlier tracks.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return tokenizer.decode(generated_ids[0])

    raise RuntimeError(
        f"No NOTE_ON event generated after {max_attempts} attempt(s)"
    )
def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Render a token string into every artifact the UI displays.

    The sequence is converted to a NoteSequence at the given tempo. That
    NoteSequence is synthesized with FluidSynth, plotted, and written to
    "midi_ouput.mid" in the current working directory; any existing file
    with that name is overwritten.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The result of
        ``gr.make_waveform`` for the synthesized audio, the MIDI file name,
        the plot returned by ``note_seq.plot_sequence``, a Markdown bullet
        list of instruments, and the whitespace-token count as a string.
        NOTE(review): the first and third element types are taken from the
        annotation; confirm them against the installed gradio/note_seq.
    """
    midi_path = "midi_ouput.mid"

    # One Markdown bullet per instrument token, in order of appearance.
    instruments_str = "\n".join(
        f"- {name}" for name in get_instruments(generated_sequence)
    )
    token_count = str(len(generated_sequence.split()))

    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize to float samples, then convert to 16-bit PCM.
    float_samples = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    pcm16 = note_seq.audio_io.float_samples_to_int16(float_samples)

    figure = note_seq.plot_sequence(note_sequence, show_figure=False)
    waveform = gr.make_waveform((SAMPLE_RATE, pcm16))
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)

    return waveform, midi_path, figure, instruments_str, token_count
def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    If removing the last track would leave no track at all, a new song is
    generated instead:
    - from the remaining header when exactly one track existed;
    - from scratch when there was no track.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute, used both for
            rendering and for any regenerated song. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, new song string, and number of tokens string.
    """
    # Splitting on 'TRACK_START' yields the text before the first track,
    # followed by one chunk per track.
    tracks = text_sequence.split("TRACK_START")

    if len(tracks) == 1:
        # No instrument at all: start from an empty sequence.
        return generate_song(text_sequence="", qpm=qpm)

    # Keep every track except the last one, restoring the 'TRACK_START'
    # separators that split() removed.
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) == 2:
        # Only one instrument existed: regenerate from the remaining header.
        # qpm is forwarded so the regenerated song keeps the requested tempo.
        return generate_song(text_sequence=new_song, qpm=qpm)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_song, qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    The song is truncated right after its last "INST=..." token. Generation
    then continues from there, so the notes of the last track (and anything
    after them) are re-sampled. If the string contains no instrument, a new
    song is generated from an empty sequence.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, new song string, and number of tokens string.
    """
    last_inst_index = text_sequence.rfind("INST=")
    if last_inst_index == -1:
        # No instrument yet: generate a whole song from an empty sequence.
        new_seed = ""
    else:
        # Keep everything up to and including the complete last INST= token.
        # The token ends at the next whitespace, or at the end of the string
        # when it is the final token. A plain find(" ") would return -1 there,
        # and slicing [:-1] would cut off the token's last character.
        inst_token = text_sequence[last_inst_index:].split(maxsplit=1)[0]
        new_seed = text_sequence[: last_inst_index + len(inst_token)]

    return generate_song(text_sequence=new_seed, qpm=qpm)
def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-render an existing song string at a different tempo.

    The song string is passed through untouched; only the rendered outputs
    (audio, MIDI, plot, ...) are recomputed with the new qpm.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, the unchanged text sequence, and number of tokens string.
    """
    rendered = get_outputs_from_string(text_sequence, qpm=qpm)
    audio, midi_file, fig, instruments_str, num_tokens = rendered
    # Splice the unchanged song string in before the token count so the
    # shape matches the other song-producing handlers.
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
def generate_song(
    genre: str = "OTHER",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates a song given a genre, temperature, initial text sequence, and tempo.

    When ``text_sequence`` is non-empty, generation continues from it and
    ``genre`` is ignored. Otherwise a fresh seed is built from ``genre``.
    The module-level model and tokenizer are used; they are not parameters.

    Args:
        genre (str, optional): The genre of the song; only used when
            text_sequence is empty. Defaults to "OTHER".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): The initial text sequence for the song.
            An empty string (or None) starts a new piece. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, generated song string, and number of tokens string.
    """
    # Treat None like "" so a cleared input starts a new piece instead of
    # handing None to the tokenizer.
    seed_string = text_sequence if text_sequence else create_seed_string(genre)
    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens