"""Gradio chat demo for the ChatMusician model (``m-a-p/ChatMusician``).

Each user turn is sent to the model. The first ABC-notation tune found in the
reply is rendered to audio (symusic + torchaudio) and to sheet music (MIDI ->
SVG via a bundled MuseScore AppImage), and both are appended to the chat.
"""
import os
import re
import copy
import time
import logging
import subprocess
from uuid import uuid4

import gradio as gr
import torch
import torchaudio
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from symusic import Score, Synthesizer
# Import order kept as in the original (spaces last).
import spaces

# MuseScore is a Qt application; render without a display.
os.environ["QT_QPA_PLATFORM"] = "offscreen"

MUSESCORE_APPIMAGE = "./MuseScore-4.1.1.232071203-x86_64.AppImage"

# Make the bundled AppImage executable. Best effort, as before: a failure only
# surfaces later as a missing SVG. A list argv avoids going through a shell.
subprocess.run(["chmod", "+x", MUSESCORE_APPIMAGE], check=False)

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# log_dir
os.makedirs("logs", exist_ok=True)
os.makedirs("tmp", exist_ok=True)
logging.basicConfig(
    filename=f'logs/chatmusician_server_{time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))}.log',
    level=logging.WARNING,
    format="%(asctime)s [%(levelname)s]: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

MODEL_PATH = "m-a-p/ChatMusician"

# An ABC tune: an "X:<n>" header line followed by one or more newline-terminated
# lines. The repetition is greedy, so it runs to the end of the text.
_ABC_PATTERN = re.compile(r"(X:\d+\n(?:[^\n]*\n)+)")

# HTML escapes applied to lines inside a fenced code block, in order.
_CODE_ESCAPES = (
    ("`", r"\`"),
    ("<", "&lt;"),
    (">", "&gt;"),
    (" ", "&nbsp;"),
    ("*", "&ast;"),
    ("_", "&lowbar;"),
    ("-", "&#45;"),
    (".", "&#46;"),
    ("!", "&#33;"),
    ("(", "&#40;"),
    (")", "&#41;"),
    ("$", "&#36;"),
)


def get_uuid():
    """Return a fresh random id; used as the per-session conversation id."""
    return str(uuid4())


def log_conversation(conversation_id, history, messages, response, generate_kwargs):
    """Record one generation turn in the server log.

    Logged at CRITICAL so the record passes the WARNING threshold configured
    in ``logging.basicConfig`` above.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))
    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "response": response,
        "generate_kwargs": generate_kwargs,
    }
    logging.critical(f"{data}")


def _parse_text(text):
    """Convert plain text into the HTML shown in the Gradio chatbot.

    - Empty lines are dropped.
    - Markdown ``` fences become ``<pre><code>`` elements.
    - Lines inside a fence are HTML-escaped, so ABC notation is shown verbatim.
    - Every line after the first is prefixed with ``<br>``.
    """
    lines = [line for line in text.split("\n") if line != ""]
    count = 0  # fences seen so far; odd means we are inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # Opening fence: whatever follows ``` is the language tag.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if count % 2 == 1:
                for old, new in _CODE_ESCAPES:
                    line = line.replace(old, new)
            lines[i] = "<br>" + line
    return "".join(lines)


def convert_history_to_text(task_history):
    """Build the ChatMusician prompt from the raw ``(query, response)`` history.

    Uses the model's ``Human: ... </s> Assistant: ...`` turn format. Earlier
    turns with an empty query are skipped. The last entry's response is ignored
    because that is the turn being generated.
    """
    history_cp = copy.deepcopy(task_history)
    text = "".join(
        f"Human: {item[0]} </s> Assistant: {item[1]} </s> "
        for item in history_cp[:-1]
        if item[0]
    )
    text += f"Human: {history_cp[-1][0]} </s> Assistant: "
    return text


def postprocess_abc(text, conversation_id):
    """Render the first ABC tune in *text* to audio and sheet music.

    Args:
        text: raw model response.
        conversation_id: names the ``tmp/<id>/`` output directory.

    Returns:
        ``(svg_path, audio_path)``, or ``(None, None)`` when no ``X:<n>`` tune
        is found. The paths are where MuseScore and torchaudio were asked to
        write; callers must check that the files actually exist.
    """
    os.makedirs(f"tmp/{conversation_id}", exist_ok=True)
    ts = time.time()
    # The appended newline lets a tune that ends the text match its last line.
    abc_notation = _ABC_PATTERN.findall(text + "\n")
    print(f"extract abc block: {abc_notation}")
    if not abc_notation:
        return None, None

    # Render the ABC as audio. 44100 Hz presumably matches the Synthesizer's
    # default output rate -- TODO confirm against the installed symusic.
    s = Score.from_abc(abc_notation[0])
    audio = Synthesizer().render(s, stereo=True)
    audio_file = f"tmp/{conversation_id}/{ts}.mp3"
    torchaudio.save(audio_file, torch.FloatTensor(audio), 44100)

    # Convert to sheet music: ABC -> MIDI -> SVG via MuseScore.
    tmp_midi = f"tmp/{conversation_id}/{ts}.mid"
    s.dump_midi(tmp_midi)
    svg_file = f"tmp/{conversation_id}/{ts}.svg"
    subprocess.run([MUSESCORE_APPIMAGE, "-f", "-o", svg_file, tmp_midi])
    return svg_file, audio_file


def _launch_demo(model, tokenizer):
    """Build and return the Gradio Blocks app around *model* and *tokenizer*."""

    @spaces.GPU
    def predict(_chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id):
        """Generate the assistant reply for the last turn of *task_history*.

        Fills in the last entry of both the chatbot (HTML) and the task history
        (raw text), logs the turn, and returns both lists.
        """
        query = task_history[-1][0]
        print("User: " + _parse_text(query))

        # model generation
        messages = convert_history_to_text(task_history)
        inputs = tokenizer(messages, return_tensors="pt", add_special_tokens=False)
        temperature = float(temperature)
        # transformers rejects temperature <= 0 when sampling, and the slider
        # allows 0.0, so treat 0 as greedy decoding instead of crashing.
        do_sample = temperature > 0
        generation_config = GenerationConfig(
            temperature=temperature if do_sample else 1.0,
            top_p=float(top_p),
            # The top-k warper requires an int (0 disables it); sliders may send floats.
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            max_new_tokens=1536,
            min_new_tokens=5,
            do_sample=do_sample,
            num_beams=1,
            num_return_sequences=1,
        )
        output_ids = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

        _chatbot[-1] = (_parse_text(query), _parse_text(response))
        # Keep the RAW query here: task_history is re-fed to the model on the
        # next turn, so it must not contain the chatbot's HTML markup.
        task_history[-1] = (query, response)

        # log
        log_conversation(conversation_id, task_history, messages, _chatbot[-1][1], generation_config.to_json_string())
        return _chatbot, task_history

    def process_and_render_abc(_chatbot, task_history, conversation_id):
        """Append the audio and sheet music of the last reply's tune, if any."""
        svg_file, audio_file = None, None
        try:
            svg_file, audio_file = postprocess_abc(task_history[-1][1], conversation_id)
        except Exception:
            # Rendering is best effort: a malformed tune must not break the chat,
            # but keep the traceback for debugging.
            logging.exception("failed to render ABC notation")
        if svg_file and audio_file:
            if os.path.exists(svg_file) and os.path.exists(audio_file):
                logging.critical(f"generate: svg: {svg_file} wav: {audio_file}")
                print(f"generate:\n{svg_file}\n{audio_file}")
                _chatbot.append((None, (str(audio_file),)))
                _chatbot.append((None, (str(svg_file),)))
            else:
                logging.error(
                    f"fail to convert: svg: {svg_file} (exists={os.path.exists(svg_file)}) "
                    f"audio: {audio_file} (exists={os.path.exists(audio_file)})"
                )
        return _chatbot

    def add_text(history, task_history, text):
        """Append the user's message to both histories and clear the input box."""
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(text, None)]
        return history, task_history, ""

    def reset_state(task_history):
        """Clear the model-side history in place and empty the chatbot."""
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        # A callable initial value gives every browser session its own id.
        conversation_id = gr.State(get_uuid)
        gr.Markdown("<center><font size=8>Chat Musician</font></center>")
        # NOTE(review): the link targets below were reconstructed after this
        # file lost its HTML markup -- confirm they point where intended.
        gr.Markdown(
            "[🌐 DemoPage](https://shanghaicannon.github.io/ChatMusician/)"
            "  |  [💻 Github](https://github.com/hf-lin/ChatMusician)"
            "  |  [📖 arXiv](https://arxiv.org/abs/2402.16153)"
            "  |  [🤗 Benchmark](https://huggingface.co/datasets/m-a-p/MusicTheoryBench)"
            "  |  [🤗 Pretrain Dataset](https://huggingface.co/datasets/m-a-p/MusicPile)"
            "  |  [🤗 SFT Dataset](https://huggingface.co/datasets/m-a-p/MusicPile-sft)  |"
        )
        gr.Markdown(
            "💡Note: The music clips on this page are auto-converted from abc notations which may "
            "not be perfect, and we recommend using better software for analysis."
        )

        chatbot = gr.Chatbot(label="Chat-Musician", elem_classes="control-height", height=750)
        query = gr.Textbox(lines=2, label="Input")
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit (发送)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")

        gr.Examples(
            examples=[
                ["Create music by following the alphabetic representation of the assigned musical structure and the given motif.\n'ABCA';X:1\nL:1/16\nM:2/4\nK:A\n['E2GB d2c2 B2A2', 'D2 C2E2 A2c2']"],
                ["Create sheet music in ABC notation from the provided text.\nAlternative title: \nThe Legacy\nKey: G\nMeter: 6/8\nNote Length: 1/8\nRhythm: Jig\nOrigin: English\nTranscription: John Chambers"],
                ["Develop a melody using the given chord pattern.\n'C', 'C', 'G/D', 'D', 'G', 'C', 'G', 'G', 'C', 'C', 'F', 'C/G', 'G7', 'C'"],
            ],
            inputs=query,
        )

        with gr.Row():
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.9,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )

        # add_text also clears the textbox (its third return value goes to `query`).
        submit_btn.click(
            add_text, [chatbot, task_history, query], [chatbot, task_history, query], queue=False
        ).then(
            predict,
            [chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id],
            [chatbot, task_history],
            show_progress=True,
            queue=True,
        ).then(process_and_render_abc, [chatbot, task_history, conversation_id], [chatbot])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)

        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    return demo


tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, device_map="cuda", torch_dtype=torch.float16
).eval()
model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)

app = _launch_demo(model, tokenizer)
app.queue().launch()