# NOTE: the import order is kept exactly as in the original file. The star
# imports can shadow or be shadowed by neighbouring names, so reordering them
# could silently change which object a name refers to.
import re
import os
import time
import torch
import shutil
import argparse
import gradio as gr
from utils import *
from config import *
from render import *
from music21 import converter
from transformers import GPT2Config
import warnings

warnings.filterwarnings('ignore')

# Two-character line prefixes that are hidden from the displayed score when
# `show_control_code` is disabled.
_CONTROL_CODE_PREFIXES = ("S:", "B:", "E:")


def abc_to_midi(abc_content, output_midi_path):
    """Convert an ABC-notation score into a MIDI file.

    Args:
        abc_content: Score text handed to ``music21.converter.parse``.
        output_midi_path: Destination path of the MIDI file.

    Returns:
        ``output_midi_path``, unchanged, so calls can be chained.
    """
    # Parse the ABC-format score.
    score = converter.parse(abc_content)
    # Save the score as a MIDI file.
    score.write('midi', fp=output_midi_path)
    return output_midi_path


def _str2bool(value):
    """Parse a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


def get_args(parser):
    """Register the sampling options on ``parser`` and parse ``sys.argv``.

    Args:
        parser: An ``argparse.ArgumentParser`` to which options are added.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser.add_argument(
        '-num_tunes', type=int, default=1,
        help='the number of independently computed returned tunes'
    )
    parser.add_argument(
        '-max_patch', type=int, default=128,
        help='integer to define the maximum length in tokens of each tune'
    )
    parser.add_argument(
        '-top_p', type=float, default=0.8,
        help='float to define the tokens that are within the sample operation of text generation'
    )
    parser.add_argument(
        '-top_k', type=int, default=8,
        help='integer to define the tokens that are within the sample operation of text generation'
    )
    parser.add_argument(
        '-temperature', type=float, default=1.2,
        help='the temperature of the sampling operation'
    )
    parser.add_argument(
        '-seed', type=int, default=None,
        help='seed for randomstate'
    )
    parser.add_argument(
        '-show_control_code', type=_str2bool, default=True,
        help='whether to show control code'
    )
    args = parser.parse_args()
    return args


def _load_model():
    """Build the TunesFormer model, load its weights and set it to eval mode.

    ``TunesFormer``, ``download``, ``device`` and the upper-case constants are
    not defined in this file; presumably they come from the star imports of
    ``utils`` / ``config`` / ``render`` -- verify against those modules.
    """
    patch_config = GPT2Config(
        num_hidden_layers=PATCH_NUM_LAYERS,
        max_length=PATCH_LENGTH,
        max_position_embeddings=PATCH_LENGTH,
        vocab_size=1
    )
    char_config = GPT2Config(
        num_hidden_layers=CHAR_NUM_LAYERS,
        max_length=PATCH_SIZE,
        max_position_embeddings=PATCH_SIZE,
        vocab_size=128
    )
    model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)

    filename = WEIGHT_PATH
    if os.path.exists(filename):
        print(f"Weights already exist at '{filename}'. Loading...")
    else:
        download()

    # Load on CPU first, then move the whole model to the target device.
    checkpoint = torch.load(filename, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    model.eval()
    return model


def generate_abc(args, region):
    """Generate ABC tunes for ``region`` and render them to MIDI/PNG/WAV.

    Args:
        args: Namespace with ``num_tunes``, ``max_patch``, ``top_p``,
            ``top_k``, ``temperature``, ``seed`` and ``show_control_code``.
        region: Region name passed to ``template`` to build the prompt and
            embedded in the output file name.

    Returns:
        A tuple ``(tunes, out_midi, png_file, wav_file)``: the generated ABC
        text, the MIDI path, and the results of ``midi2png`` / ``midi2wav``.

    Side effects:
        Prints the hyperparameters and the tunes while they are generated,
        creates ``./tmp`` and writes the rendered files.
    """
    patchilizer = Patchilizer()
    model = _load_model()
    prompt = template(region)
    tunes = ""

    num_tunes = args.num_tunes
    max_patch = args.max_patch
    top_p = args.top_p
    top_k = args.top_k
    temperature = args.temperature
    seed = args.seed
    show_control_code = args.show_control_code

    print(" HYPERPARAMETERS ".center(60, "#"), '\n')
    for key, value in vars(args).items():
        print(f'{key}: {value}')
    print('\n', " OUTPUT TUNES ".center(60, "#"))

    start_time = time.time()
    for i in range(num_tunes):
        tune = f"X:{i + 1}\n{prompt}"
        # The capture group keeps the "\n" separators as their own items.
        lines = re.split(r'(\n)', tune)
        tune = ""
        skip = False
        for line in lines:
            if show_control_code or not line.startswith(_CONTROL_CODE_PREFIXES):
                if not skip:
                    print(line, end="")
                    tune += line
                skip = False
            else:
                # Also drop the newline item that follows a hidden line.
                skip = True

        input_patches = torch.tensor(
            [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
            device=device
        )
        if tune == "":
            tokens = None
            # Keep the name bound: the generation loop reads it unconditionally.
            remaining_tokens = ""
        else:
            # Characters of the prompt not covered by the encoded patches are
            # fed to the model as the start of the next patch.
            prefix = patchilizer.decode(input_patches[0])
            remaining_tokens = prompt[len(prefix):]
            tokens = torch.tensor(
                [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
                device=device
            )

        while input_patches.shape[1] < max_patch:
            predicted_patch, seed = model.generate(
                input_patches,
                tokens,
                top_p=top_p,
                top_k=top_k,
                temperature=temperature,
                seed=seed
            )
            # The partial-patch tokens only apply to the first step.
            tokens = None
            if predicted_patch[0] == patchilizer.eos_token_id:
                break

            next_bar = patchilizer.decode([predicted_patch])
            if show_control_code or not next_bar.startswith(_CONTROL_CODE_PREFIXES):
                print(next_bar, end="")
                tune += next_bar
            if next_bar == "":
                break

            next_bar = remaining_tokens + next_bar
            remaining_tokens = ""
            predicted_patch = torch.tensor(
                patchilizer.bar2patch(next_bar),
                device=device
            ).unsqueeze(0)
            input_patches = torch.cat(
                [input_patches, predicted_patch.unsqueeze(0)],
                dim=1
            )

        tunes += f"{tune}\n\n"
        print("\n")

    print("Generation time: {:.2f} seconds".format(time.time() - start_time))

    create_dir('./tmp')
    timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
    out_midi = abc_to_midi(tunes, f'./tmp/[{region}]{timestamp}.mid')
    # NOTE(review): add_path() is defined elsewhere; presumably it prepares the
    # environment needed by the renderers below -- confirm in its module.
    add_path()
    png_file = midi2png(out_midi)
    wav_file = midi2wav(out_midi)
    return tunes, out_midi, png_file, wav_file


def inference(region):
    """Gradio callback: generate a tune for ``region`` from a clean ``./tmp``.

    Any existing ``./tmp`` directory is removed first, so files from earlier
    requests are discarded. Sampling options are read from the command line
    via ``get_args``.

    Returns:
        The ``(tunes, out_midi, png_file, wav_file)`` tuple of ``generate_abc``.
    """
    if os.path.exists('./tmp'):
        shutil.rmtree('./tmp')

    parser = argparse.ArgumentParser()
    args = get_args(parser)
    return generate_abc(args, region)


# Build the UI: region picker and button on the left, results on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            region_opt = gr.Dropdown(
                choices=[
                    'Mondstadt',
                    'Liyue',
                    'Inazuma',
                    'Sumeru',
                    'Fontaine'
                ],
                value='Liyue',
                label='Region'
            )
            gen_btn = gr.Button("Generate")

        with gr.Column():
            wav_output = gr.Audio(label='Audio', type='filepath')
            dld_midi = gr.components.File(label="Download MIDI")
            abc_output = gr.TextArea(label='abc score')
            img_score = gr.Image(label='Staff', type='filepath')

    # Output order must match the tuple returned by `inference`.
    gen_btn.click(
        inference,
        inputs=region_opt,
        outputs=[abc_output, dld_midi, img_score, wav_output]
    )

demo.launch(share=True)