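"""Gradio app that generates Genshin Impact (Teyvat) region-style tunes with TunesFormer.

The script samples ABC notation from a fine-tuned TunesFormer checkpoint and converts
the result to MusicXML, MIDI, WAV, PDF, MXL and a staff image for download.
"""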
import re
import os
import time
import torch
import shutil
import argparse
import warnings
import gradio as gr
from transformers import GPT2Config
from model import Patchilizer, TunesFormer
from convert import abc2xml, xml2, xml2img
from utils import (
    PATCH_NUM_LAYERS,
    PATCH_LENGTH,
    CHAR_NUM_LAYERS,
    PATCH_SIZE,
    SHARE_WEIGHTS,
    WEIGHTS_PATH,
    TEMP_DIR,
    TEYVAT,
    DEVICE,
    EN_US,
    _L,
)


def get_args(parser: argparse.ArgumentParser):
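    """Register the sampling hyperparameters on `parser` and return the parsed args."""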
    parser.add_argument(
        "-num_tunes",
        type=int,
        default=1,
        help="number of tunes to generate independently",
    )
    parser.add_argument(
        "-max_patch",
        type=int,
        default=128,
        help="maximum length of each generated tune, in patches",
    )
    parser.add_argument(
        "-top_p",
        type=float,
        default=0.8,
        help="nucleus (top-p) sampling threshold for generation",
    )
    parser.add_argument(
        "-top_k",
        type=int,
        default=8,
        help="top-k sampling cutoff for generation (0 disables it)",
    )
    parser.add_argument(
        "-temperature",
        type=float,
        default=1.2,
        help="temperature of the sampling operation",
    )
    parser.add_argument("-seed", type=int, default=None, help="random seed for sampling")
    parser.add_argument(
        "-show_control_code",
        type=bool,
        default=False,
        help="whether to show control codes (S:, B:, E:) in the output",
    )
    return parser.parse_args()


def generate_music(args, region: str):
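    """Sample `args.num_tunes` ABC tunes in the given region style and convert them.

    Returns the converted outputs (audio, midi, pdf, xml, mxl, abc text, jpg).
    """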
    patchilizer = Patchilizer()
    # Patch-level (bar-level) decoder configuration
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    # Character-level decoder configuration
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    # Load the fine-tuned checkpoint on CPU, then move the model to the target device
    checkpoint = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])
    model = model.to(DEVICE)
    model.eval()
    # The region tag (A:) conditions the model on the requested style
    prompt = f"A:{region}\n"
    tunes = ""
    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code
    print(" Hyper params ".center(60, "#"), "\n")
    arg_dict: dict = vars(args)
    for key in arg_dict.keys():
        print(f"{key}: {str(arg_dict[key])}")
    print("\n", " Output tunes ".center(60, "#"))
    start_time = time.time()
    for i in range(num_tunes):
        # Build the ABC header for this tune: index, title, credit and region prompt
        title_artist = f"T:{region} Style Fragment\nC:Generated by AI\n"
        tune = f"X:{str(i + 1)}\n{title_artist + prompt}"
        lines = re.split(r"(\n)", tune)
        tune = ""
        skip = False
        # Echo the header, hiding control-code lines (S:, B:, E:) unless requested
        for line in lines:
            if show_control_code or line[:2] not in ["S:", "B:", "E:"]:
                if not skip:
                    print(line, end="")
                    tune += line
                skip = False
            else:
                skip = True
        # Encode the prompt into patches, dropping the special end patch ([:-1])
        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]], device=DEVICE
        )
        if tune == "":
            tokens = None
        else:
            # Characters of the prompt not yet consumed by complete patches
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix) :]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=DEVICE,
            )
        # Autoregressively generate one bar (patch) at a time until EOS or max_patch
        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed,
            )
            tokens = None
            if predicted_patch[0] != patchilizer.eos_token_id:
                next_bar = patchilizer.decode([predicted_patch])
                # Control-code bars are kept in the tune text but not printed
                if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
                    print(next_bar, end="")
                tune += next_bar
                if next_bar == "":
                    break
                # Prepend any leftover prompt characters before re-encoding the bar
                next_bar = remaining_tokens + next_bar
                remaining_tokens = ""
                predicted_patch = torch.tensor(
                    patchilizer.bar2patch(next_bar), device=DEVICE
                ).unsqueeze(0)
                input_patches = torch.cat(
                    [input_patches, predicted_patch.unsqueeze(0)], dim=1
                )
            else:
                break
        tunes += f"{tune}\n\n"
        print("\n")
    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    try:
        # Convert the generated ABC notation into the downloadable formats
        xml = abc2xml(tunes, f"{TEMP_DIR}/[{region}]{timestamp}.musicxml")
        midi = xml2(xml, "mid")
        audio = xml2(xml, "wav")
        pdf, jpg = xml2img(xml)
        mxl = xml2(xml, "mxl")
        return audio, midi, pdf, xml, mxl, tunes, jpg
    except Exception as e:
        # The model occasionally emits ABC the converters reject; regenerate and retry
        print(f"Invalid abc generated: {e}, retrying...")
        return generate_music(args, region)


def infer(p: float, k: int, t: float, region: str):
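    """Gradio callback: apply the UI sampling settings and run one generation pass."""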
    status = "Success"
    audio = midi = pdf = xml = mxl = tunes = jpg = None
    try:
        # Start from a clean temporary directory for the output files
        if os.path.exists(TEMP_DIR):
            shutil.rmtree(TEMP_DIR)
        os.makedirs(TEMP_DIR)
        parser = argparse.ArgumentParser()
        args = get_args(parser)
        # Override the CLI defaults with the values chosen in the UI
        args.top_p = p
        args.top_k = k
        args.temperature = t
        audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
            args, region if EN_US else TEYVAT[region]
        )
    except Exception as e:
        status = f"{e}"
    return status, audio, midi, pdf, xml, mxl, tunes, jpg


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    # Region options: English names when EN_US, otherwise the Chinese keys of TEYVAT
    opts = list(TEYVAT.values()) if EN_US else list(TEYVAT.keys())
    # UI labels and the description are Chinese keys resolved by the _L() localization helper
    gr.Interface(
        fn=infer,
        inputs=[
            gr.Slider(0.01, 1.0, 0.8, step=0.01, label=_L("Top-P 采样")),
            gr.Slider(0, 80, 8, step=1, label=_L("Top-K 采样 (0 为关闭)")),
            gr.Slider(0.01, 2.0, 1.2, step=0.01, label=_L("温度参数")),
            gr.Dropdown(
                choices=opts,
                value=opts[0],
                label=_L("地区风格"),
            ),
        ],
        outputs=[
            gr.Textbox(label=_L("状态栏"), show_copy_button=True),
            gr.Audio(label=_L("音频"), type="filepath"),
            gr.File(label=_L("下载 MIDI")),
            gr.File(label=_L("下载 PDF 乐谱")),
            gr.File(label=_L("下载 MusicXML")),
            gr.File(label=_L("下载 MXL")),
            gr.Textbox(label=_L("ABC 记谱"), show_copy_button=True),
            gr.Image(label=_L("五线谱"), type="filepath", show_share_button=False),
        ],
        flagging_mode="never",
        title=_L("原神音乐生成"),
        description=_L(
            """
欢迎使用此创空间, 此创空间基于 Tunesformer 开源项目制作，完全免费。当前模型还在调试中，计划在原神主线杀青后，所有国家地区角色全部开放后，二创音乐会齐全且样本均衡，届时重新微调模型并添加现实风格筛选辅助游戏各国家输出强化学习，以提升输出区分度与质量。注：崩铁方面数据工程正在运作中，未来也希望随主线杀青而基线化。<br>
数据来源: <a href="https://musescore.org">MuseScore</a> 标签来源: <a href="https://genshin-impact.fandom.com/wiki/Genshin_Impact_Wiki">Genshin Impact Wiki | Fandom</a> 模型基础: <a href="https://github.com/sander-wood/tunesformer">Tunesformer</a>
"""
        ),
    ).launch(ssr_mode=False)
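
# Running locally (a sketch; assumes this file is saved as app.py and that the
# checkpoint referenced by WEIGHTS_PATH plus the local model/convert/utils modules
# are available):
#   python app.py
# The Gradio UI then exposes the Top-P / Top-K / temperature sliders and the
# region dropdown wired to infer().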