|
from typing import List, Tuple |
|
|
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import note_seq |
|
from matplotlib.figure import Figure |
|
from numpy import ndarray |
|
import torch |
|
|
|
from constants import GM_INSTRUMENTS, SAMPLE_RATE |
|
from string_to_notes import token_sequence_to_note_sequence |
|
from model import get_model_and_tokenizer |
|
|
|
import json |
|
|
|
# Preferred torch device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and tokenizer used by generate_new_instrument for sampling.
model, tokenizer = get_model_and_tokenizer()

# Instrument names loaded from instruments.json. The functions in this module
# look names up with ``instruments.index(name)`` and treat position 0 as drums
# and position i (i >= 1) as the token value ``INST=<i - 1>``.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale-dependent default encoding.
with open('instruments.json', 'r', encoding='utf-8') as f:
    instruments = json.load(f)
|
|
|
|
|
def create_seed_string(genre: str = "OTHER", artist: str = "OTHER", instrument: str = "0") -> str:
    """
    Creates a seed string for generating a new piece.

    The three values are inserted verbatim, so "RANDOM" needs no special
    handling: it simply appears as ``GENRE=RANDOM`` / ``ARTIST=RANDOM``.

    Args:
        genre (str, optional): The genre of the piece. Defaults to "OTHER".
        artist (str, optional): The artist of the piece. Defaults to "OTHER".
        instrument (str, optional): The value written after ``INST=`` for the
            first track. Defaults to "0".

    Returns:
        str: The seed string.
    """
    return f"PIECE_START GENRE={genre} ARTIST={artist} TRACK_START INST={instrument}"
|
|
|
|
|
def get_instruments(text_sequence: str) -> List[str]:
    """
    Collects a display name for every ``INST=`` token in a text sequence.

    Args:
        text_sequence (str): Whitespace-separated token sequence.

    Returns:
        List[str]: One entry per token that starts with ``INST=``, in order of
        appearance: "Drums" for ``INST=DRUMS``, otherwise the entry of
        ``GM_INSTRUMENTS`` at the integer that follows the prefix.

    Raises:
        ValueError: If the text after ``INST=`` is neither "DRUMS" nor an integer.
    """
    prefix = "INST="
    names: List[str] = []
    for token in text_sequence.split():
        if not token.startswith(prefix):
            continue
        program = token[len(prefix):]
        if program == "DRUMS":
            names.append("Drums")
            continue
        names.append(GM_INSTRUMENTS[int(program)])
    return names
|
|
|
|
|
def change_last_instrument(
    text_sequence: str,
    instrument: str,
    temp: float = 0.75,
    qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Replaces the last ``INST=`` token of a song string and re-renders the song.

    Args:
        text_sequence (str): The song string.
        instrument (str): Instrument name; it must be an entry of the
            module-level ``instruments`` list.
        temp (float, optional): Accepted but not used by this function.
            Defaults to 0.75.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The outputs of
        ``get_outputs_from_string`` for the edited song, with the edited song
        string inserted before the number-of-tokens string.

    Raises:
        ValueError: If ``instrument`` is not in ``instruments``.
    """
    # Position 0 of the list stands for drums; every other position maps to
    # the token value one below it.
    position = instruments.index(instrument)
    inst_value = 'DRUMS' if position == 0 else str(position - 1)

    # Walk backwards so that only the final INST= token is rewritten.
    tokens = text_sequence.split()
    for idx in range(len(tokens) - 1, -1, -1):
        if "INST=" in tokens[idx]:
            tokens[idx] = f"INST={inst_value}"
            break
    updated_sequence = ' '.join(tokens)

    outputs = get_outputs_from_string(updated_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = outputs
    return audio, midi_file, fig, instruments_str, updated_sequence, num_tokens
|
|
|
|
|
|
|
def generate_new_instrument(seed: str, temp: float = 0.75) -> str:
    """
    Generates a new instrument sequence from a given seed and temperature.

    Sampling is repeated until the newly generated part contains "NOTE_ON";
    there is no retry limit.

    Args:
        seed (str): The seed string for the generation.
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.

    Returns:
        str: The decoded full sequence (seed plus generated continuation).
    """
    # The prompt and the stop token are the same for every attempt, so they
    # are encoded once instead of on each retry.
    input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
    seed_length = input_ids.shape[-1]
    eos_token_id = tokenizer.encode("TRACK_END")[0]

    while True:
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=2048,
            do_sample=True,
            temperature=temp,
            eos_token_id=eos_token_id,
        )

        # Only the continuation decides whether the sample is kept: a track
        # without any NOTE_ON token is rejected and sampling starts over.
        new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
        if "NOTE_ON" in new_generated_sequence:
            return tokenizer.decode(generated_ids[0])
|
|
|
|
|
def get_outputs_from_string(
    generated_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str]:
    """
    Renders a token sequence into the outputs shown to the user.

    The MIDI file is always written to the same relative path,
    "midi_ouput.mid", so each call overwrites the previous file.

    Args:
        generated_sequence (str): The generated sequence of tokens.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str]: The result of
        ``gr.make_waveform`` for the synthesized audio, the MIDI file name,
        the result of ``note_seq.plot_sequence``, a bulleted list of
        instrument names (one per line), and the number of
        whitespace-separated tokens as a string.
        NOTE(review): the first element is whatever ``gr.make_waveform``
        returns; confirm that the ``ndarray`` annotation matches it.
    """
    midi_path = "midi_ouput.mid"

    # Text-only outputs.
    instrument_names = get_instruments(generated_sequence)
    instruments_str = "\n".join([f"- {name}" for name in instrument_names])
    num_tokens = str(len(generated_sequence.split()))

    note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)

    # Synthesize the notes and convert the float samples to int16.
    float_samples = note_seq.fluidsynth(note_sequence, sample_rate=SAMPLE_RATE)
    int16_data = note_seq.audio_io.float_samples_to_int16(float_samples)

    fig = note_seq.plot_sequence(note_sequence, show_figure=False)
    audio = gr.make_waveform((SAMPLE_RATE, int16_data))
    note_seq.note_sequence_to_midi_file(note_sequence, midi_path)

    return audio, midi_path, fig, instruments_str, num_tokens
|
|
|
|
|
def remove_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Removes the last instrument from a song string and returns the various output formats.

    When the song has at most one track, removing it would leave nothing to
    render, so a new track is generated with ``generate_song`` instead (using
    that function's default genre, artist, instrument and temperature).

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, new song string, and number of tokens string.
    """
    tracks = text_sequence.split("TRACK_START")
    # Everything before the final TRACK_START; this is "" when the song
    # contains no TRACK_START at all.
    new_song = "TRACK_START".join(tracks[:-1])

    if len(tracks) <= 2:
        # Zero or one track. ``new_song`` is "" (no track) or the text that
        # preceded the only track, which generate_song uses as its seed.
        # qpm is forwarded so the regenerated song keeps the requested tempo
        # rather than falling back to generate_song's default.
        return generate_song(text_sequence=new_song, qpm=qpm)

    audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
        new_song, qpm
    )
    return audio, midi_file, fig, instruments_str, new_song, num_tokens
|
|
|
|
|
# NOTE(review): these six statements sit at module level between two function
# definitions and repeat generate_song's parameter list. They look like a
# leftover fragment of a signature rather than intentional globals — confirm
# that nothing depends on them and remove them if so. Read as statements, the
# trailing commas turn the first five values into one-element tuples.
genre: str = "OTHER",
artist: str = "KATE_BUSH",
instrument: str = "Acoustic Grand Piano",
temp: float = 0.75,
text_sequence: str = "",
qpm: int = 120
|
|
|
def regenerate_last_instrument(
    text_sequence: str, qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Regenerates the last instrument in a song string and returns the various output formats.

    The last track is dropped and a new track for the same instrument is
    generated with ``generate_song``.

    Args:
        text_sequence (str): The song string.
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, new song string, and number of tokens string.

    Raises:
        ValueError: If ``text_sequence`` contains no ``INST=`` token, so there
            is no instrument to regenerate.
    """
    # Find the instrument of the last track by scanning the tokens backwards.
    instrument_id = None
    for token in reversed(text_sequence.split()):
        if 'INST=' in token:
            instrument_id = token.split('=')[1]
            break
    if instrument_id is None:
        # Previously this fell through to an unbound-variable error.
        raise ValueError(
            "Cannot regenerate the last instrument: the sequence contains no INST= token."
        )

    # Map the token value back to the name generate_song expects: the
    # module-level list is offset by one relative to the token value.
    if instrument_id == "DRUMS":
        instrument = "Drums"
    else:
        instrument = instruments[int(instrument_id) + 1]

    # Keep everything before the final TRACK_START as the new seed.
    new_seed = "TRACK_START".join(text_sequence.split("TRACK_START")[:-1])

    return generate_song(instrument=instrument, text_sequence=new_seed, qpm=qpm)
|
|
|
|
|
def change_tempo(
    text_sequence: str, qpm: int
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Re-renders an unchanged song string at a new tempo.

    Args:
        text_sequence (str): The song string.
        qpm (int): The new quarter notes per minute.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The outputs of
        ``get_outputs_from_string`` at the given tempo, with the unmodified
        song string inserted before the number-of-tokens string.
    """
    outputs = get_outputs_from_string(text_sequence, qpm=qpm)
    audio, midi_file, fig, instruments_str, num_tokens = outputs
    return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
|
|
|
|
|
def generate_song(
    genre: str = "OTHER",
    artist: str = "KATE_BUSH",
    instrument: str = "Acoustic Grand Piano",
    temp: float = 0.75,
    text_sequence: str = "",
    qpm: int = 120
) -> Tuple[ndarray, str, Figure, str, str, str]:
    """
    Generates one new track and renders the resulting song.

    With an empty ``text_sequence`` a new piece is started from ``genre``,
    ``artist`` and ``instrument``. Otherwise a new track for ``instrument`` is
    appended to ``text_sequence``, and ``genre`` and ``artist`` are not used.

    Args:
        genre (str, optional): The genre of the song. Defaults to "OTHER".
        artist (str, optional): The artist style to inspire the song. Defaults to "KATE_BUSH".
        instrument (str, optional): Instrument name; it must be an entry of the
            module-level ``instruments`` list. Defaults to "Acoustic Grand Piano".
        temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
        text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
        qpm (int, optional): The quarter notes per minute. Defaults to 120.

    Returns:
        Tuple[ndarray, str, Figure, str, str, str]: The audio waveform, MIDI file name, plot figure,
        instruments string, generated song string, and number of tokens string.

    Raises:
        ValueError: If ``instrument`` is not in ``instruments``.
    """
    # Position 0 of the list stands for drums; every other position maps to
    # the token value one below it.
    position = instruments.index(instrument)
    inst_value = 'DRUMS' if position == 0 else str(position - 1)

    if text_sequence != "":
        seed_string = text_sequence + " TRACK_START INST=" + inst_value
    else:
        seed_string = create_seed_string(genre, artist, inst_value)

    generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)

    outputs = get_outputs_from_string(generated_sequence, qpm)
    audio, midi_file, fig, instruments_str, num_tokens = outputs
    return audio, midi_file, fig, instruments_str, generated_sequence, num_tokens
|
|
|
|