import random
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from model.music_transformer import MusicTransformer
from processor import decode_midi, encode_midi
from utilities.constants import TOKEN_END, TOKEN_PAD, TORCH_LABEL_TYPE
from utilities.device import get_device, use_cuda
# Hugging Face Hub repo and checkpoint file for the finetuned MusicTransformer.
REPO_ID = "Launchpad/lofi-bytes"
FILENAME = "weights_maestro_finetuned.pickle"
# Offset used when slicing a sequence window from the start of an encoded MIDI.
SEQUENCE_START = 0
# NOTE(review): OUTPUT_PATH appears unused in this file (generate() writes "output.mid").
OUTPUT_PATH = "./output_midi"
# rpr flag passed to MusicTransformer (presumably enables relative position
# representations -- confirm against model.music_transformer).
RPR = True
# TARGET_SEQ_LENGTH = 1023
# Total token length requested from model.generate() (primer included).
TARGET_SEQ_LENGTH = 512
# Number of tokens taken from the uploaded MIDI as the generation primer.
NUM_PRIME = 65
# Maximum sequence length configured on the model; must match the checkpoint.
MAX_SEQUENCE = 2048
# Transformer architecture hyperparameters; must match the saved weights.
N_LAYERS = 6
NUM_HEADS = 8
D_MODEL = 512
DIM_FEEDFORWARD = 1024
# Beam width passed to model.generate(); presumably 0 means sampling rather
# than beam search -- TODO confirm against the model implementation.
BEAM = 0
# NOTE(review): FORCE_CPU, ALLOWED_EXTENSIONS and UPLOAD_FOLDER appear unused here.
FORCE_CPU = False
ALLOWED_EXTENSIONS = {'mid'}
UPLOAD_FOLDER = './uploaded_midis'
# NOTE(review): generated_midi is never reassigned in this file -- likely leftover.
generated_midi = None
# Request CUDA usage; the actual device is resolved via get_device().
use_cuda(True)
# Build the MusicTransformer with hyperparameters matching the checkpoint
# and move it to the active device (CUDA if available, per use_cuda(True)).
model = MusicTransformer(
    n_layers=N_LAYERS,
    num_heads=NUM_HEADS,
    d_model=D_MODEL,
    dim_feedforward=DIM_FEEDFORWARD,
    max_sequence=MAX_SEQUENCE,
    rpr=RPR
).to(get_device())
# Download the finetuned weights from the Hub (cached locally by huggingface_hub)
# and deserialize them directly onto the active device.
state_dict = torch.load(
    hf_hub_download(repo_id=REPO_ID, filename=FILENAME),
    map_location=get_device()
)
model.load_state_dict(state_dict)
def generate(input_midi):
    """Generate a lofi continuation of an uploaded MIDI file.

    Encodes the input MIDI into event tokens, takes the first NUM_PRIME
    tokens as a primer, extends it with the MusicTransformer up to
    TARGET_SEQ_LENGTH tokens, and writes the decoded result to "output.mid".

    Returns the output file path, or None if the input encodes to no tokens.
    """
    encoded_tokens = encode_midi(input_midi)
    if not encoded_tokens:
        return

    primer_tokens, _ = process_midi(encoded_tokens, NUM_PRIME, random_seq=False)
    primer_tokens = torch.tensor(primer_tokens, dtype=TORCH_LABEL_TYPE, device=get_device())

    # saves a pretty_midi at file_path
    # decode_midi(primer_tokens[:NUM_PRIME].cpu().numpy(), file_path=f_path)
    decode_midi(primer_tokens[:NUM_PRIME].cpu().numpy())

    # GENERATION -- inference only, so gradients are disabled.
    model.eval()
    with torch.set_grad_enabled(False):
        # NOTE: model.generate() returns a MIDI stored as an ARRAY given a primer
        token_seq = model.generate(primer_tokens[:NUM_PRIME], TARGET_SEQ_LENGTH, beam=BEAM)

    out_path = "output.mid"
    # NOTE: decode_midi() converts the token array into a pretty_midi.PrettyMIDI
    # (which carries instrument info etc.) and saves it to out_path; the returned
    # object itself is not needed here, only the file.
    decode_midi(token_seq[0].cpu().numpy(), file_path=out_path)
    return out_path
def process_midi(raw_mid, max_seq, random_seq):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Takes in pre-processed raw midi (a list of event tokens) and returns the
    (input, target) pair for seq2seq use, each of length max_seq, where the
    target is the input shifted left by one token.

    Sequences shorter than max_seq are padded with TOKEN_PAD and the target is
    terminated with TOKEN_END. Longer sequences are windowed to max_seq + 1
    tokens, either from a random offset (random_seq=True) or from the start.
    ----------
    """
    x = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
    tgt = torch.full((max_seq, ), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())
    raw_len = len(raw_mid)
    full_seq = max_seq + 1 # Performing seq2seq
    if(raw_len == 0):
        return x, tgt
    if(raw_len < full_seq):
        x[:raw_len] = raw_mid
        tgt[:raw_len-1] = raw_mid[1:]
        # BUGFIX: was tgt[raw_len] = TOKEN_END, which left a stray PAD at
        # tgt[raw_len-1] and raised IndexError when raw_len == max_seq
        # (tgt has exactly max_seq elements). raw_mid[1:] fills indices
        # 0..raw_len-2, so END belongs at raw_len-1.
        tgt[raw_len-1] = TOKEN_END
    else:
        # Randomly selecting a range
        if(random_seq):
            end_range = raw_len - full_seq
            start = random.randint(SEQUENCE_START, end_range)
        # Always taking from the start to as far as we can
        else:
            start = SEQUENCE_START
        end = start + full_seq
        data = raw_mid[start:end]
        # NOTE(review): in this branch x/tgt are list slices of raw_mid, not
        # the padded tensors created above -- callers handle both forms.
        x = data[:max_seq]
        tgt = data[1:full_seq]
    return x, tgt
# Gradio UI: a header row (project logo + blurb) above the file-in/file-out
# generation interface backed by generate().
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Launchpad project logo, rendered as a static image.
            gr.Image(
                "https://www.ocf.berkeley.edu/~launchpad/media/uploads/project_logos/410912267_278779401866686_2517511436172822307_n_0iVwDxI.png",
                elem_id="logo-img",
                show_label=False,
                show_share_button=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        with gr.Column(scale=3):
            # Project description and credits. FIX: "stamples" -> "samples".
            gr.Markdown("""lofi-bytes is a [Launchpad](https://launchpad.studentorg.berkeley.edu/) project (Spring 2023) that generates lofi tracks from input MIDI samples using a MusicTransformer model.
**Model**: [lofi-bytes](https://huggingface.co/Launchpad/lofi-bytes)
**Project Leader**: Alicia Wang
**Members**: Alena Chao, Eric Liu, Zane Mogannam, Chloe Wong, Iris Zhou
**Advisors**: Vincent Lim, Winston Liu
"""
            )
    # File upload -> generate() -> downloadable MIDI, with two bundled examples.
    gr.Interface(
        fn=generate,
        inputs=gr.File(),
        outputs=gr.File(),
        examples=["uploaded_midis/ghibli_castle_in_the_sky.mid", "uploaded_midis/am_i_blue_jazz.mid"]
    )
# Entry point: launch the Gradio app with a public share link and
# server-side errors surfaced in the UI.
if __name__ == '__main__':
    demo.launch(share=True, show_error=True)