"""Gradio demo server for the ChatMusician model.

The app lets a user chat with the model, then extracts the first ABC-notation
block from each reply and renders it to sheet music (SVG) and audio (MP3)
using the external ``abc2midi`` and ``musescore`` command-line tools.
"""
import os
import re
import copy
import time
import logging
import subprocess
from uuid import uuid4

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import spaces

# Presumably lets the Qt-based musescore binary run without a display.
os.environ['QT_QPA_PLATFORM'] = 'offscreen'
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# log_dir
os.makedirs("logs", exist_ok=True)
os.makedirs("tmp", exist_ok=True)
logging.basicConfig(
    filename=f'logs/chatmusician_server_{time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())}.log',
    level=logging.WARNING,
    format='%(asctime)s [%(levelname)s]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

MODEL_PATH = 'm-a-p/ChatMusician'

# Characters escaped inside fenced code blocks so the chatbot's Markdown/HTML
# renderer shows them literally. None of the replacement strings contains a
# character that is itself a key, so a single translate() pass is equivalent
# to the sequential replace() chain this table stands in for.
_CODE_ESCAPES = str.maketrans({
    "`": r"\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def get_uuid():
    """Return a fresh random UUID string, used as the per-session conversation id."""
    return str(uuid4())


# todo
def log_conversation(conversation_id, history, messages, response, generate_kwargs):
    """Write one conversation turn to the log file.

    The record is logged at CRITICAL level so that it passes the WARNING
    threshold configured in ``logging.basicConfig`` above.
    """
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())
    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "response": response,
        "generate_kwargs": generate_kwargs,
    }
    logging.critical(f"{data}")


def _parse_text(text):
    """Convert plain chat text into HTML for display in the chatbot widget.

    Empty lines are dropped. A line containing a triple-backtick fence opens
    or closes a ``<pre><code>`` block; the text after the last backtick on an
    opening fence is used as the language class. Every other line except the
    very first is prefixed with ``<br>``, and lines inside a code block are
    additionally HTML-escaped.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def convert_history_to_text(task_history):
    """Flatten the chat history into the single prompt string fed to the model.

    ``task_history`` is a list of ``(user_text, assistant_text)`` pairs whose
    last entry is the pending turn (its assistant text is ignored). Earlier
    turns with an empty user text are skipped.

    NOTE(review): the ``</s>`` separators follow the prompt format published
    for ChatMusician ("Human: ... </s> Assistant: "); they were missing from
    the copy of this file under review, apparently stripped together with the
    HTML tags. Confirm against the deployed version.
    """
    history_cp = copy.deepcopy(task_history)
    text = "".join(
        f"Human: {item[0]} </s> Assistant: {item[1]} </s> "
        for item in history_cp[:-1]
        if item[0]
    )
    text += f"Human: {history_cp[-1][0]} </s> Assistant: "
    return text


# todo
def postprocess_abc(text, conversation_id):
    """Render the first ABC-notation block found in ``text``.

    The block is written to ``tmp/<conversation_id>/<timestamp>.abc``,
    converted to MIDI with ``abc2midi`` and then to SVG and MP3 with
    ``musescore``.

    Returns:
        ``(svg_path, audio_path)`` when an ABC block was found, otherwise
        ``(None, None)``. The paths are returned whether or not the external
        tools succeeded, so callers must check that the files exist.
    """
    os.makedirs(f"tmp/{conversation_id}", exist_ok=True)
    abc_pattern = r'(X:\d+\n(?:[^\n]*\n)+)'
    # The trailing newline lets a block that ends the text still match.
    abc_notation = re.findall(abc_pattern, text + '\n')
    print(f'extract abc block: {abc_notation}')
    if not abc_notation:
        return None, None

    ts = time.time()
    # Write the ABC text to a temporary file
    tmp_abc = f"tmp/{conversation_id}/{ts}.abc"
    with open(tmp_abc, "w", encoding="utf-8") as abc_file:
        abc_file.write(abc_notation[0])
    # Convert abc notation to midi
    tmp_midi = f'tmp/{conversation_id}/{ts}.mid'
    subprocess.run(["abc2midi", str(tmp_abc), "-o", tmp_midi])
    # Convert abc notation to SVG
    svg_file = f'tmp/{conversation_id}/{ts}.svg'
    audio_file = f'tmp/{conversation_id}/{ts}.mp3'
    subprocess.run(["musescore", "-o", svg_file, tmp_midi], capture_output=True, text=True)
    subprocess.run(["musescore", "-o", audio_file, tmp_midi])
    # Presumably musescore writes one SVG per page with a "-<page>" suffix;
    # only the first page is returned.
    return svg_file.replace(".svg", "-1.svg"), audio_file


def _launch_demo(model, tokenizer):
    """Build the Gradio Blocks UI around ``model``/``tokenizer`` and return it."""

    @spaces.GPU
    def predict(_chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id):
        """Generate the assistant reply for the pending turn in ``task_history``."""
        query = task_history[-1][0]
        print("User: " + _parse_text(query))
        # model generation
        messages = convert_history_to_text(task_history)
        inputs = tokenizer(messages, return_tensors="pt", add_special_tokens=False)

        # The temperature slider goes down to 0.0, which transformers rejects
        # when sampling; treat it as a request for greedy decoding instead.
        temperature = float(temperature)
        do_sample = temperature > 0.0
        sampling_kwargs = {}
        if do_sample:
            sampling_kwargs = {
                "temperature": temperature,
                "top_p": float(top_p),
                "top_k": int(top_k),
            }
        generation_config = GenerationConfig(
            repetition_penalty=float(repetition_penalty),
            max_new_tokens=1536,
            min_new_tokens=5,
            do_sample=do_sample,
            num_beams=1,
            num_return_sequences=1,
            **sampling_kwargs,
        )
        response = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs['attention_mask'].to(model.device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(response[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        _chatbot[-1] = (_parse_text(query), _parse_text(response))
        # Keep the raw query here: task_history is what gets turned back into
        # the model prompt, so it must not contain display-only HTML markup.
        task_history[-1] = (query, response)
        # log
        log_conversation(conversation_id, task_history, messages, _chatbot[-1][1], generation_config.to_json_string())
        return _chatbot, task_history

    def process_and_render_abc(_chatbot, task_history, conversation_id):
        """Append rendered audio and sheet music for the latest reply, if any."""
        svg_file, wav_file = None, None
        try:
            svg_file, wav_file = postprocess_abc(task_history[-1][1], conversation_id)
        except Exception:
            # Rendering is best-effort: keep the chat usable, but keep the traceback.
            logging.exception("failed to render ABC notation")

        if svg_file and wav_file:
            if os.path.exists(svg_file) and os.path.exists(wav_file):
                logging.critical(f"generate: svg: {svg_file} wav: {wav_file}")
                print(f"generate:\n{svg_file}\n{wav_file}")
                _chatbot.append((None, (str(wav_file),)))
                _chatbot.append((None, (str(svg_file),)))
            else:
                logging.error(f"fail to convert: {svg_file[:-4]}.musicxml")
        return _chatbot

    def add_text(history, task_history, text):
        """Append the user's message as a new pending turn to both histories."""
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(text, None)]
        return history, task_history, ""

    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value="")

    def reset_state(task_history):
        """Empty the model-side history in place and return an empty chat display."""
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        conversation_id = gr.State(get_uuid)
        # NOTE(review): the three Markdown blocks below look like their HTML
        # markup (heading/centering tags and link hrefs) was stripped from the
        # copy under review; restore it from the project page if so.
        gr.Markdown(
            """

Chat Musician

"""
        )
        gr.Markdown("""\
🌐 DemoPage  |  💻 Github  |  📖 arXiv  |  🤗 Benchmark  |  🤗 Pretrain Dataset  |  🤗 SFT Dataset
""")
        gr.Markdown("""\
💡Note: The music clips on this page is auto-converted using musescore2 which may not be perfect, and we recommend using better software for analysis.
""")
        chatbot = gr.Chatbot(label='ChatMusician', elem_classes="control-height", height=750)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit (发送)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            # regen_btn = gr.Button("🤔️ Regenerate (重试)")

        gr.Examples(
            examples=[
                ["Create music by following the alphabetic representation of the assigned musical structure and the given motif.\n'ABCA';X:1\nL:1/16\nM:2/4\nK:A\n['E2GB d2c2 B2A2', 'D2 C2E2 A2c2']"],
                ["Develop a melody using the given chord pattern.\n'C', 'C', 'G/D', 'D', 'G', 'C', 'G', 'G', 'C', 'C', 'F', 'C/G', 'G7', 'C'"],
                ["Create sheet music in ABC notation from the provided text.\nAlternative title: \nThe Legacy\nKey: G\nMeter: 6/8\nNote Length: 1/8\nRhythm: Jig\nOrigin: English\nTranscription: John Chambers"],
            ],
            inputs=query
        )

        with gr.Row():
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.9,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )

        # Submit: record the user turn, generate the reply, then render any ABC in it.
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history], queue=False).then(
            predict,
            [chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id],
            [chatbot, task_history],
            show_progress=True,
            queue=True
        ).then(process_and_render_abc, [chatbot, task_history, conversation_id], [chatbot])
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)

        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    return demo


tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map='cuda',
    torch_dtype=torch.float16
).eval()
model.generation_config = GenerationConfig.from_pretrained(
    MODEL_PATH
)

app = _launch_demo(model, tokenizer)
app.queue().launch()