Spaces:
Running
Running
import re | |
import os | |
import time | |
import torch | |
import shutil | |
import argparse | |
import gradio as gr | |
from utils import * | |
from config import * | |
from render import * | |
from music21 import converter | |
from transformers import GPT2Config | |
import warnings | |
# Suppress all Python warnings for the whole process (presumably to keep
# noisy library warnings out of the console / Space logs — confirm intent).
warnings.filterwarnings('ignore')
def abc_to_midi(abc_content, output_midi_path):
    """Parse ABC-notation text with music21 and write it out as a MIDI file.

    Args:
        abc_content: ABC-notation source text (not a file path).
        output_midi_path: Destination path of the ``.mid`` file to write.

    Returns:
        ``output_midi_path`` unchanged, so the result can be passed straight on.
    """
    # Parse the ABC score. Passing format='abc' tells music21 explicitly that
    # the string is ABC data instead of relying on format auto-detection.
    score = converter.parse(abc_content, format='abc')
    # Save the parsed score as a MIDI file.
    score.write('midi', fp=output_midi_path)
    return output_midi_path
def _str2bool(value):
    """Convert a textual boolean ('true'/'false', 'yes'/'no', '1'/'0', ...) for argparse.

    ``type=bool`` cannot be used for this: argparse would call ``bool("False")``,
    which is ``True`` because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(parser):
    """Register the generation options on *parser* and parse ``sys.argv``.

    Args:
        parser: An ``argparse.ArgumentParser`` to add the options to.

    Returns:
        The parsed ``argparse.Namespace`` with attributes ``num_tunes``,
        ``max_patch``, ``top_p``, ``top_k``, ``temperature``, ``seed`` and
        ``show_control_code``.
    """
    parser.add_argument('-num_tunes', type=int, default=1,
                        help='the number of independently computed returned tunes')
    parser.add_argument('-max_patch', type=int, default=128,
                        help='integer to define the maximum length in tokens of each tune')
    parser.add_argument('-top_p', type=float, default=0.8,
                        help='float to define the tokens that are within the sample operation of text generation')
    parser.add_argument('-top_k', type=int, default=8,
                        help='integer to define the tokens that are within the sample operation of text generation')
    parser.add_argument('-temperature', type=float, default=1.2,
                        help='the temperature of the sampling operation')
    parser.add_argument('-seed', type=int, default=None,
                        help='seed for randomstate')
    # _str2bool instead of bool: with type=bool, "-show_control_code False"
    # would still parse as True.
    parser.add_argument('-show_control_code', type=_str2bool,
                        default=True, help='whether to show control code')
    args = parser.parse_args()
    return args
# Model built on first use and reused by later calls, so the checkpoint is read
# from disk once per process instead of on every "Generate" request.
_MODEL = None


def _load_model():
    """Build TunesFormer, load its weights (downloading them if absent), and cache it.

    Returns:
        The model moved to ``device`` and set to eval mode. Later calls return
        the same cached object.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1,
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128,
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
    filename = WEIGHT_PATH
    if os.path.exists(filename):
        print(f"Weights already exist at '{filename}'. Loading...")
    else:
        # NOTE(review): assumes download() writes the checkpoint to WEIGHT_PATH.
        download()
    checkpoint = torch.load(filename, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    # NOTE(review): `device` is expected to come from one of the star imports.
    model = model.to(device)
    model.eval()
    _MODEL = model
    return model


def _print_hyperparameters(args):
    """Print every parsed option under a banner, then the output-tunes banner."""
    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    for key, value in vars(args).items():
        print(f'{key}: {str(value)}')
    print('\n', " OUTPUT TUNES ".center(60, "#"))


def _filter_header(header, show_control_code):
    """Print *header* and return it, dropping S:/B:/E: control-code lines.

    Unless *show_control_code* is true, each control-code line and the newline
    that follows it are removed from both the printed and the returned text.
    """
    kept = ""
    skip_newline = False
    # The capturing group keeps the "\n" separators as their own pieces.
    for piece in re.split(r'(\n)', header):
        if show_control_code or piece[:2] not in ["S:", "B:", "E:"]:
            if not skip_newline:
                print(piece, end="")
                kept += piece
            skip_newline = False
        else:
            skip_newline = True
    return kept


def _generate_tune(model, patchilizer, prompt, index, seed, args):
    """Sample one tune patch by patch from *prompt*, printing it as it grows.

    Args:
        model: Loaded TunesFormer model (see ``_load_model``).
        patchilizer: ``Patchilizer`` used to encode/decode patches.
        prompt: Prompt text produced by ``template(region)``.
        index: 1-based tune number written into the ``X:`` header line.
        seed: Seed passed to ``model.generate``.
        args: Parsed options (max_patch, top_p, top_k, temperature,
            show_control_code).

    Returns:
        ``(tune, seed)`` — the tune text and the seed returned by the last
        ``model.generate`` call, so the next tune continues from it.
    """
    show_control_code = args.show_control_code
    tune = _filter_header(f"X:{index}\n{prompt}", show_control_code)
    input_patches = torch.tensor(
        [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
        device=device,
    )
    # Prompt characters not covered by the decoded input patches (the last
    # encoded patch is dropped above). They are passed to model.generate as
    # `tokens` and prepended to the first generated bar. Initialized here so it
    # is always bound, even when `tune` is empty.
    remaining_tokens = ""
    if tune == "":
        tokens = None
    else:
        prefix = patchilizer.decode(input_patches[0])
        remaining_tokens = prompt[len(prefix):]
        tokens = torch.tensor(
            [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
            device=device,
        )
    while input_patches.shape[1] < args.max_patch:
        predicted_patch, seed = model.generate(
            input_patches,
            tokens,
            top_p=args.top_p,
            top_k=args.top_k,
            temperature=args.temperature,
            seed=seed,
        )
        tokens = None
        if predicted_patch[0] == patchilizer.eos_token_id:
            break
        next_bar = patchilizer.decode([predicted_patch])
        if show_control_code or next_bar[:2] not in ["S:", "B:", "E:"]:
            print(next_bar, end="")
            tune += next_bar
        if next_bar == "":
            break
        next_bar = remaining_tokens + next_bar
        remaining_tokens = ""
        predicted_patch = torch.tensor(
            patchilizer.bar2patch(next_bar),
            device=device,
        ).unsqueeze(0)
        input_patches = torch.cat(
            [input_patches, predicted_patch.unsqueeze(0)],
            dim=1,
        )
    return tune, seed


def generate_abc(args, region):
    """Generate ABC tune(s) for *region* and render them to MIDI, PNG and WAV.

    Args:
        args: ``argparse.Namespace`` with num_tunes, max_patch, top_p, top_k,
            temperature, seed and show_control_code.
        region: Region name passed to ``template`` to build the prompt; also
            embedded in the output MIDI file name under ``./tmp``.

    Returns:
        ``(tunes, out_midi, png_file, wav_file)``. ``tunes`` is the generated
        ABC text, each tune followed by a blank line. The others are the values
        returned by ``abc_to_midi``, ``midi2png`` and ``midi2wav``.
    """
    patchilizer = Patchilizer()
    model = _load_model()
    prompt = template(region)
    seed = args.seed
    _print_hyperparameters(args)
    start_time = time.time()
    tunes = ""
    # Pure inference: no autograd graph is needed while sampling.
    with torch.no_grad():
        for i in range(args.num_tunes):
            tune, seed = _generate_tune(model, patchilizer, prompt, i + 1, seed, args)
            tunes += f"{tune}\n\n"
            print("\n")
    print("Generation time: {:.2f} seconds".format(time.time() - start_time))
    create_dir('./tmp')
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    out_midi = abc_to_midi(tunes, f'./tmp/[{region}]{timestamp}.mid')
    png_file = midi2png(out_midi)
    wav_file = midi2wav(out_midi)
    return tunes, out_midi, png_file, wav_file
def inference(region):
    """Gradio click handler: clear ``./tmp``, parse options, and generate for *region*.

    Args:
        region: Region name selected in the dropdown.

    Returns:
        The ``(tunes, out_midi, png_file, wav_file)`` tuple from
        ``generate_abc``, in the order of the button's ``outputs`` list.
    """
    # Drop the previous request's output files. EAFP instead of
    # exists()+rmtree(): another request could delete the directory between
    # the check and the removal.
    try:
        shutil.rmtree('./tmp')
    except FileNotFoundError:
        pass
    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_abc(args, region)
# Gradio UI: choose a region and press "Generate" to get the audio, a MIDI
# download, the ABC text and a rendered staff image.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            region_opt = gr.Dropdown(
                choices=[
                    'Mondstadt', 'Liyue', 'Inazuma', 'Sumeru', 'Fontaine'
                ],
                value='Liyue',
                label='Region',
            )
            gen_btn = gr.Button("Generate")
        with gr.Column():
            wav_output = gr.Audio(label='Audio', type='filepath')
            dld_midi = gr.components.File(label="Download MIDI")
            abc_output = gr.TextArea(label='abc score')
            img_score = gr.Image(label='Staff', type='filepath')
    # Output order must match the tuple returned by inference():
    # (tunes, midi path, png path, wav path).
    gen_btn.click(
        inference,
        inputs=region_opt,
        outputs=[abc_output, dld_midi, img_score, wav_output],
    )

# Launch only when run as a script, so importing this module (e.g. to reuse
# generate_abc) does not start a web server.
if __name__ == "__main__":
    demo.launch(share=True)