import json
from typing import List, Tuple

import gradio as gr
import note_seq
import torch
from matplotlib.figure import Figure
from numpy import ndarray

from constants import GM_INSTRUMENTS, SAMPLE_RATE
from model import get_model_and_tokenizer
from string_to_notes import token_sequence_to_note_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the model once at import time.
model, tokenizer = get_model_and_tokenizer()

# Instrument names for the UI dropdown. Index 0 is "Drums"; the remaining
# entries follow the General MIDI program list, offset by one.
with open("instruments.json", "r") as f:
    instruments = json.load(f)


def create_seed_string(
    genre: str = "OTHER", artist: str = "OTHER", instrument: str = "0"
) -> str:
    """
    Creates a seed string for generating a new piece.

    Args:
        genre (str, optional): The genre of the piece. Defaults to "OTHER".
        artist (str, optional): The artist style for the piece. Defaults to "OTHER".
        instrument (str, optional): The instrument token value: a General MIDI
            program number as a string, or "DRUMS". Defaults to "0".

    Returns:
        str: The seed string, e.g. "PIECE_START GENRE=OTHER ARTIST=OTHER TRACK_START INST=0"
            for the defaults.
    """
    # RANDOM and concrete values share the same template, so one f-string
    # covers every genre/artist combination.
    return f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START INST={instrument}"
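
# For reference, the token grammar this module reads and writes (as far as
# this file shows) is roughly:
#   PIECE_START GENRE=<g> ARTIST=<a> TRACK_START INST=<n|DRUMS> ... NOTE_ON ... TRACK_END
# where each TRACK_START ... TRACK_END span is one instrument track.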


def get_instruments(text_sequence: str) -> List[str]:
    """
    Extracts the list of instruments from a text sequence.

    Args:
        text_sequence (str): The text sequence.

    Returns:
        List[str]: The list of instrument names.
    """
    # Named to avoid shadowing the module-level `instruments` list.
    found_instruments = []
    for part in text_sequence.split():
        if part.startswith("INST="):
            if part[5:] == "DRUMS":
                found_instruments.append("Drums")
            else:
                index = int(part[5:])
                found_instruments.append(GM_INSTRUMENTS[index])
    return found_instruments
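
# Example (hypothetical input): a sequence containing "INST=0" and "INST=DRUMS"
# yields ["Acoustic Grand Piano", "Drums"], assuming GM_INSTRUMENTS follows the
# standard General MIDI ordering with program 0 as Acoustic Grand Piano.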


def change_last_instrument(
    text_sequence: str,
    instrument: str,
    temp: float = 0.75,
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Replaces the instrument of the last track in a song string and re-renders the song.

    Args:
        text_sequence (str): The song string.
        instrument (str): The new instrument name from the UI dropdown.
        temp (float, optional): Unused; kept for interface compatibility. Defaults to 0.75.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name,
            plot figure, instruments string, modified song string, and number of tokens string.
    """
    # Map the dropdown name to the token value: index 0 is "Drums",
    # everything else is a General MIDI program number (offset by one).
    instrument_idx = instruments.index(instrument)
    if instrument_idx == 0:
        instrument_idx = "DRUMS"
    else:
        instrument_idx = str(instrument_idx - 1)
    # Rewrite the last INST= token in the sequence.
    tokens = text_sequence.split()
    for token_idx in reversed(range(len(tokens))):
        if "INST=" in tokens[token_idx]:
            tokens[token_idx] = f"INST={instrument_idx}"
            break
    text_sequence = " ".join(tokens)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.

    Returns:
        str: The generated instrument sequence.
    """
    seed_length = len(tokenizer.encode(seed))
    # Sample repeatedly until the newly generated portion contains at least
    # one NOTE_ON token; otherwise the new track would be silent.
    while True:
        # Encode the conditioning tokens and move them to the model's device.
        input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
        # Generate more tokens, stopping at the end-of-track token.
        eos_token_id = tokenizer.encode("TRACK_END")[0]
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )
        generated_sequence = tokenizer.decode(generated_ids[0])
        # Check whether the part generated beyond the seed contains "NOTE_ON".
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return generated_sequence
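
# Example (hypothetical call): generate_new_instrument(create_seed_string())
# returns the decoded seed plus one new track, ending at the first TRACK_END
# token (or after max_new_tokens, whichever comes first).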


def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Converts a generated sequence into various output formats including audio, MIDI, plot, etc.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, and number of tokens string.
    """
    track_instruments = get_instruments(generated_sequence)
    instruments_str = "\n".join(f"- {instrument}" for instrument in track_instruments)
    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
    # Synthesize the note sequence to audio and convert it to 16-bit PCM.
    array_of_floats = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    num_tokens = str(len(generated_sequence.split()))
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, "midi_output.mid")
    return audio, "midi_output.mid", fig, instruments_str, num_tokens


def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, new song string, and number of tokens string.
    """
    # Split the song into tracks on 'TRACK_START', drop the last track, and
    # join the rest back together (the join restores the removed delimiters).
    tracks = text_sequence.split("TRACK_START")
    new_song = "TRACK_START".join(tracks[:-1])
    if len(tracks) == 2:
        # Header plus a single track: removing it leaves only the seed header,
        # so generate a fresh track from that header.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=new_song
        )
    elif len(tracks) == 1:
        # No TRACK_START at all: start over from an empty sequence.
        audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
            text_sequence=""
        )
    else:
        audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
            new_song, qpm
        )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, new song string, and number of tokens string.
    """

    def remove_last_track(text_sequence: str) -> str:
        # Drop the last track and restore the 'TRACK_START' delimiters
        # removed by split.
        tracks = text_sequence.split("TRACK_START")
        return "TRACK_START".join(tracks[:-1])

    # Find the instrument of the last track by scanning the tokens backwards.
    instrument_id = None
    for token in reversed(text_sequence.split()):
        if "INST=" in token:
            instrument_id = token.split("=")[1]
            break
    if instrument_id is None:
        raise ValueError("No INST= token found in the text sequence.")
    if instrument_id == "DRUMS":
        instrument = "Drums"
    else:
        # Dropdown index 0 is "Drums", so General MIDI program N maps to
        # instruments[N + 1]; program 0 is 'Acoustic Grand Piano'. See
        # https://soundprogramming.net/file-formats/general-midi-instrument-list/
        instrument = instruments[int(instrument_id) + 1]
    new_seed = remove_last_track(text_sequence=text_sequence)
    audio, midi_file, fig, instruments_str, new_song, num_tokens = generate_song(
        instrument=instrument, text_sequence=new_seed, qpm=qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens


def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Changes the tempo of a song string and returns the various output formats.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, text sequence, and number of tokens string.
    """
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        text_sequence, qpm=qpm
    )
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens


def generate_song(
    genre: str = "OTHER",
    artist: str = "KATE_BUSH",
    instrument: str = "Acoustic Grand Piano",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120,
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates a song given a genre, artist, instrument, temperature, initial text sequence, and tempo.

    Args:
        genre (str, optional): The genre of the song. Defaults to "OTHER".
        artist (str, optional): The artist style to inspire the song. Defaults to "KATE_BUSH".
        instrument (str, optional): The instrument name from the UI dropdown.
            Defaults to "Acoustic Grand Piano".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
            instruments string, generated song string, and number of tokens string.
    """
    # Map the dropdown name to the token value: index 0 is "Drums",
    # everything else is a General MIDI program number (offset by one).
    instrument_idx = instruments.index(instrument)
    if instrument_idx == 0:
        instrument_token = "DRUMS"
    else:
        instrument_token = str(instrument_idx - 1)
    if text_sequence == "":
        # No existing song: build a fresh seed from the UI selections.
        seed_string = create_seed_string(genre, artist, instrument_token)
    else:
        # Extend the existing song with a new track for the chosen instrument.
        seed_string = text_sequence + " TRACK_START INST=" + instrument_token
    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        generated_sequence, qpm
    )
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
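

# A minimal usage sketch, assuming the Space's companion files
# (instruments.json, constants.py, model.py, string_to_notes.py) are present,
# as they are when this module runs inside the app. The functions above are
# normally wired to Gradio callbacks; calling the main entry point directly
# looks like this:
if __name__ == "__main__":
    audio, midi_file, fig, instruments_str, song, num_tokens = generate_song(
        genre="OTHER",
        artist="KATE_BUSH",
        instrument="Acoustic Grand Piano",
        temp=0.75,
        qpm=120,
    )
    print(f"Generated {num_tokens} tokens; MIDI written to {midi_file}")
    print("Instruments used:\n" + instruments_str)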