# hoyoMusic / app.py
# Author: MuGeminorum
# Commit: 19b4090 ("refine codes")
# Hugging Face file-viewer residue: raw | history blame | No virus | 6.73 kB
import re
import os
import time
import torch
import shutil
import argparse
import gradio as gr
from utils import *
from config import *
from convert import *
from transformers import GPT2Config
import warnings
# Install a process-wide filter that ignores every warning category.
warnings.simplefilter("ignore")
def _str2bool(value):
    """Convert a command-line string to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string (including "False"), so a real parser is needed.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(parser, argv=None):
    """Register the sampling options on *parser* and parse them.

    Args:
        parser: an ``argparse.ArgumentParser`` to add the options to.
        argv: optional list of argument strings; ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before.

    Returns:
        The parsed ``argparse.Namespace`` with ``num_tunes``, ``max_patch``,
        ``top_p``, ``top_k``, ``temperature``, ``seed`` and
        ``show_control_code``.
    """
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="the number of independently computed returned tunes",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="integer to define the maximum length in tokens of each tune",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="float to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="integer to define the tokens that are within the sample operation of text generation",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="the temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
    parser.add_argument(
        "-show_control_code",
        # Parse "False"/"0"/"no" as False; plain bool() would treat them as True.
        type=_str2bool,
        default=True,
        help="whether to show control code",
    )
    return parser.parse_args(argv)
# Process-wide cache for the loaded TunesFormer, so the checkpoint is built
# and read from disk once instead of on every generation request.
# NOTE(review): assumes model.generate does not mutate model state between
# calls (the model is only used in eval mode here) — confirm in utils.
_MODEL_CACHE = {}


def _load_model():
    """Build TunesFormer, load the weights at WEIGHT_PATH and cache the model.

    Calls ``download()`` when WEIGHT_PATH does not exist yet. The returned
    model has been moved to ``device`` and switched to eval mode.
    """
    model = _MODEL_CACHE.get("model")
    if model is not None:
        return model
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    filename = WEIGHT_PATH
    if os.path.exists(filename):
        print(f"Weights already exist at '{filename}'. Loading...")
    else:
        download()
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    checkpoint = torch.load(filename, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])
    model = model.to(device)
    model.eval()
    _MODEL_CACHE["model"] = model
    return model


def generate_abc(args, region):
    """Generate ABC tunes for *region* and export them to several formats.

    Args:
        args: namespace produced by ``get_args`` (num_tunes, max_patch, top_p,
            top_k, temperature, seed, show_control_code).
        region: region name; passed to ``template`` to build the prompt and
            used in the title line and the output file names under ./tmp.

    Returns:
        ``(tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file)``:
        the ABC text of all generated tunes followed by whatever the
        conversion helpers return (presumably file paths — see convert.py).
    """
    patchilizer = Patchilizer()
    model = _load_model()
    prompt = template(region)
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code
    print(" HYPERPARAMETERS ".center(60, "#"), "\n")
    for key, value in vars(args).items():
        print(f"{key}: {str(value)}")
    print("\n", " OUTPUT TUNES ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        title_artist = f"T:{region} Fragment\nC:Generated by AI\n"
        tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
        # Split while keeping the "\n" separators as their own elements, so a
        # hidden control-code line can drop its trailing newline as well.
        lines = re.split(r"(\n)", tune)
        tune = ""
        skip = False
        for line in lines:
            if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
                if not skip:
                    print(line, end="")
                    tune += line
                skip = False
            else:
                skip = True
        # Prompt patches without the final special patch; generation resumes
        # from here.
        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=device
        )
        if tune == "":
            tokens = None
            # Bound here too: it is read after the first generated patch.
            remaining_tokens = ""
        else:
            # Prompt characters not covered by whole patches are fed to the
            # model as a character-level prefix for the first generated patch.
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix) :]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=device,
            )
        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            # The character prefix applies only to the first generated patch.
            tokens = None
            if predicted_patch[0] != patchilizer.eos_token_id:
                next_bar = patchilizer.decode([predicted_patch])
                if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
                    print(next_bar, end="")
                    tune += next_bar
                if next_bar == "":
                    break
                # Re-patch the leftover prompt characters together with the new bar.
                next_bar = remaining_tokens + next_bar
                remaining_tokens = ""
                predicted_patch = torch.tensor(
                    patchilizer.bar2patch(next_bar), device=device
                ).unsqueeze(0)
                input_patches = torch.cat(
                    [input_patches, predicted_patch.unsqueeze(0)], dim=1
                )
            else:
                break
        tunes += f"{tune}\n\n"
        print("\n")
    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    os.makedirs("./tmp", exist_ok=True)
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    out_midi = abc_to_midi(tunes, f"./tmp/[{region}]{timestamp}.mid")
    out_xml = abc_to_musicxml(tunes, f"./tmp/[{region}]{timestamp}.musicxml")
    out_mxl = musicxml_to_mxl(f"./tmp/[{region}]{timestamp}.musicxml")
    pdf_file, jpg_file = mxl2jpg(out_mxl)
    wav_file = midi2wav(out_midi)
    return tunes, out_midi, pdf_file, out_xml, out_mxl, jpg_file, wav_file
def inference(region):
    """Gradio callback: clear ./tmp, read the sampling options, then generate.

    Args:
        region: region name selected in the dropdown.

    Returns:
        The tuple produced by ``generate_abc``, matching the Gradio outputs.
    """
    # Remove artefacts from the previous request. Deleting directly (EAFP)
    # avoids the race where ./tmp disappears between an exists() check and
    # rmtree().
    # NOTE(review): this also removes files of any concurrently running request.
    try:
        shutil.rmtree("./tmp")
    except FileNotFoundError:
        pass
    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_abc(args, region)
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: region selector and the trigger button.
        with gr.Column():
            region_opt = gr.Dropdown(
                label="Region genre",
                value="Mondstadt",
                choices=["Mondstadt", "Liyue", "Inazuma", "Sumeru", "Fontaine"],
            )
            gen_btn = gr.Button("Generate")
        # Right column: everything one generation run produces.
        with gr.Column():
            wav_output = gr.Audio(type="filepath", label="Audio")
            dld_midi = gr.components.File(label="Download MIDI")
            pdf_score = gr.components.File(label="Download PDF score")
            dld_xml = gr.components.File(label="Download MusicXML")
            dld_mxl = gr.components.File(label="Download MXL")
            abc_output = gr.Textbox(show_copy_button=True, label="abc score")
            img_score = gr.Image(type="filepath", label="Staff")
    # Order must match the tuple returned by inference() / generate_abc().
    result_slots = [
        abc_output,
        dld_midi,
        pdf_score,
        dld_xml,
        dld_mxl,
        img_score,
        wav_output,
    ]
    gen_btn.click(inference, outputs=result_slots, inputs=region_opt)

demo.launch(share=True)