ChatMusician / app.py
hf-lin
delete fuse
19d4c67
raw history blame
No virus
12.4 kB
import os
import re
import copy
import time
import logging
import subprocess
from uuid import uuid4
import gradio as gr
import torch
import torchaudio
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from symusic import Score, Synthesizer
import spaces
# Run Qt headless so the MuseScore AppImage can render without a display.
os.environ['QT_QPA_PLATFORM'] = 'offscreen'
# Make the bundled MuseScore AppImage executable; postprocess_abc runs it to
# turn MIDI into SVG. An argument list is used instead of a shell string since
# no shell features are needed. As before, a failure here is not fatal
# (no check=True).
subprocess.run(["chmod", "+x", "MuseScore-4.1.1.232071203-x86_64.AppImage"])
# Disable the memory-efficient and flash scaled-dot-product attention kernels
# (the reason is not stated in this file).
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# Directory for the server log file.
os.makedirs("logs", exist_ok=True)
# Scratch space for per-conversation audio/MIDI/SVG renders.
os.makedirs("tmp", exist_ok=True)
# One log file per server start, named by the UTC start time. Only WARNING and
# above are written; conversation records are emitted at CRITICAL.
logging.basicConfig(
    filename=f'logs/chatmusician_server_{time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))}.log',
    level=logging.WARNING,
    format='%(asctime)s [%(levelname)s]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
MODEL_PATH = 'm-a-p/ChatMusician'
def get_uuid():
    """Return a fresh random (version 4) UUID as its canonical 36-character string."""
    new_id = uuid4()
    return f"{new_id}"
# todo
def log_conversation(conversation_id, history, messages, response, generate_kwargs):
    """Write one chat turn to the log as a single CRITICAL-level record.

    The message is the ``str`` of a dict with, in order: the conversation id,
    a UTC timestamp, the task history, the prompt text sent to the model, the
    response, and the serialized generation settings. Logging at CRITICAL
    means the record passes the WARNING threshold configured in this file.
    """
    utc_now = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())
    record = dict(
        conversation_id=conversation_id,
        timestamp=utc_now,
        history=history,
        messages=messages,
        response=response,
        generate_kwargs=generate_kwargs,
    )
    logging.critical(str(record))
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def convert_history_to_text(task_history):
    """Render the chat history as the model's ``Human: ... </s> Assistant:`` prompt.

    Every turn except the last contributes
    ``"Human: {query} </s> Assistant: {answer} </s> "``; earlier turns whose
    query is falsy (e.g. ``None`` or ``""``) are skipped. The last turn
    contributes only its query followed by an open ``"Assistant: "`` slot for
    the model to complete.

    Args:
        task_history: non-empty sequence of (query, answer) pairs. It is only
            read, never modified.

    Returns:
        The prompt string.

    Raises:
        IndexError: if ``task_history`` is empty.
    """
    # The history is only read here, so the previous per-call deepcopy of the
    # whole conversation was unnecessary work.
    prompt = "".join(
        f"Human: {turn[0]} </s> Assistant: {turn[1]} </s> "
        for turn in task_history[:-1]
        if turn[0]
    )
    return prompt + f"Human: {task_history[-1][0]} </s> Assistant: "
# todo
def postprocess_abc(text, conversation_id):
    """Render the first ABC tune found in *text* to audio and an SVG score.

    A tune starts at an ``X:<number>`` line and runs until the first blank line
    or the end of the text, following the ABC convention that an empty line
    ends a tune. The tune is parsed with symusic, synthesized to
    ``tmp/<conversation_id>/<timestamp>.mp3`` and dumped to a MIDI file next to
    it. The MIDI is then converted to SVG by the bundled MuseScore AppImage.

    Args:
        text: model response that may contain ABC notation.
        conversation_id: name of the per-conversation output sub-directory.

    Returns:
        ``(svg_path, audio_path)`` when a tune was found, else ``(None, None)``.
        MuseScore's exit status is not checked, so the SVG path is returned even
        if the file was not produced; callers should check that it exists.

    Raises:
        Whatever symusic / torchaudio raise on unparseable notation or a
        failed write (the caller in this file wraps the call in try/except).
    """
    # `[^\n]+` (not `[^\n]*`) stops at the first empty line. The old pattern
    # also swallowed blank lines and any prose that followed the tune.
    abc_notation = re.findall(r'(X:\d+\n(?:[^\n]+\n)+)', text + '\n')
    print(f'extract abc block: {abc_notation}')
    if not abc_notation:
        return None, None

    # Only create the output directory when there is something to write.
    out_dir = f"tmp/{conversation_id}"
    os.makedirs(out_dir, exist_ok=True)
    stem = f"{out_dir}/{time.time()}"

    # Render ABC as audio. 44100 Hz is hard-coded as the save rate; presumably
    # it matches the Synthesizer's default output rate -- TODO confirm.
    score = Score.from_abc(abc_notation[0])
    audio = Synthesizer().render(score, stereo=True)
    audio_file = f'{stem}.mp3'
    torchaudio.save(audio_file, torch.FloatTensor(audio), 44100)

    # Convert ABC -> MIDI -> SVG via MuseScore.
    tmp_midi = f'{stem}.mid'
    score.dump_midi(tmp_midi)
    svg_file = f'{stem}.svg'
    subprocess.run(["./MuseScore-4.1.1.232071203-x86_64.AppImage", "-f", "-o", svg_file, tmp_midi])
    return svg_file, audio_file
def _launch_demo(model, tokenizer):
    """Build the ChatMusician Gradio app around an already-loaded model.

    Args:
        model: causal LM used by ``predict``; prompt tensors are moved to
            ``model.device`` before generation.
        tokenizer: tokenizer used to encode prompts and decode replies.

    Returns:
        The ``gr.Blocks`` app. The caller is expected to queue and launch it.
    """

    @spaces.GPU
    def predict(_chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id):
        """Generate the assistant reply for the newest turn.

        ``task_history[-1]`` is expected to be ``(query, None)`` as appended by
        ``add_text``. Returns the updated ``(_chatbot, task_history)``:
        ``_chatbot`` holds the HTML-rendered query and reply for display, while
        ``task_history`` keeps the raw query and reply text that
        ``convert_history_to_text`` turns into the next prompt.
        """
        query = task_history[-1][0]
        print("User: " + _parse_text(query))
        # model generation
        messages = convert_history_to_text(task_history)
        inputs = tokenizer(messages, return_tensors="pt", add_special_tokens=False)
        temperature = float(temperature)
        # The temperature slider allows 0.0, but transformers requires a
        # strictly positive temperature when sampling; use greedy decoding then.
        do_sample = temperature > 0
        generation_config = GenerationConfig(
            temperature=temperature if do_sample else 1.0,
            top_p=float(top_p),
            # Slider values may arrive as floats; top-k must be an int (0 disables it).
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            max_new_tokens=1536,
            min_new_tokens=5,
            do_sample=do_sample,
            num_beams=1,
            num_return_sequences=1,
        )
        output_ids = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs['attention_mask'].to(model.device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )
        # Decode only the newly generated tokens (skip the prompt prefix).
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        _chatbot[-1] = (_parse_text(query), _parse_text(response))
        # Keep the raw query. Storing _parse_text(query) here fed HTML markup
        # ("<br>", "&nbsp;", ...) back into the prompt on the next turn.
        task_history[-1] = (query, response)
        # log
        log_conversation(conversation_id, task_history, messages, _chatbot[-1][1], generation_config.to_json_string())
        return _chatbot, task_history

    def process_and_render_abc(_chatbot, task_history, conversation_id):
        """Append the rendered audio and score of the latest reply to the chat.

        Rendering is best-effort: failures are logged and the chatbot is
        returned without the media messages.
        """
        svg_file, audio_file = None, None
        try:
            svg_file, audio_file = postprocess_abc(task_history[-1][1], conversation_id)
        except Exception:
            # Keep the traceback in the log rather than only the message text.
            logging.exception("failed to render ABC notation")
        if svg_file and audio_file:
            if os.path.exists(svg_file) and os.path.exists(audio_file):
                logging.critical(f"generate: svg: {svg_file} wav: {audio_file}")
                print(f"generate:\n{svg_file}\n{audio_file}")
                _chatbot.append((None, (str(audio_file),)))
                _chatbot.append((None, (str(svg_file),)))
            else:
                # Report the files that were actually expected (the old message
                # named a ".musicxml" path that is never produced).
                logging.error(f"fail to convert: {svg_file} / {audio_file}")
        return _chatbot

    def add_text(history, task_history, text):
        """Append a new (query, None) turn to both histories and clear the input box."""
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(text, None)]
        return history, task_history, ""

    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value="")

    def reset_state(task_history):
        """Empty the stored history in place and return an empty chatbot."""
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        conversation_id = gr.State(get_uuid)
        gr.Markdown(
            """<h1><center>Chat Musician</center></h1>"""
        )
        gr.Markdown("""\
<center><font size=4><a href="https://ezmonyi.github.io/ChatMusician/">🌐 DemoPage</a>&nbsp;|
&nbsp;<a href="https://github.com/hf-lin/ChatMusician">💻 Github</a>&nbsp;|
&nbsp;<a href="http://arxiv.org/abs/2402.16153">📖 arXiv</a>&nbsp;|
&nbsp;<a href="https://huggingface.co/datasets/m-a-p/MusicTheoryBench">🤗 Benchmark</a>&nbsp;|
&nbsp;<a href="https://huggingface.co/datasets/m-a-p/MusicPile">🤗 Pretrain Dataset</a>&nbsp;|
&nbsp;<a href="https://huggingface.co/datasets/m-a-p/MusicPile-sft">🤗 SFT Dataset</a></font></center>""")
        gr.Markdown("""\
<center><font size=4>💡Note: The music clips on this page are auto-converted from ABC notation, which may not be perfect,
and we recommend using better software for analysis.</font></center>""")
        chatbot = gr.Chatbot(label='Chat-Musician', elem_classes="control-height", height=750)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])
        with gr.Row():
            submit_btn = gr.Button("🚀 Submit (发送)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            # regen_btn = gr.Button("🤔️ Regenerate (重试)")
        gr.Examples(
            examples=[
                ["Create music by following the alphabetic representation of the assigned musical structure and the given motif.\n'ABCA';X:1\nL:1/16\nM:2/4\nK:A\n['E2GB d2c2 B2A2', 'D2 C2E2 A2c2']"],
                ["Create sheet music in ABC notation from the provided text.\nAlternative title: \nThe Legacy\nKey: G\nMeter: 6/8\nNote Length: 1/8\nRhythm: Jig\nOrigin: English\nTranscription: John Chambers"],
                ["Develop a melody using the given chord pattern.\n'C', 'C', 'G/D', 'D', 'G', 'C', 'G', 'G', 'C', 'C', 'F', 'C/G', 'G7', 'C'"]
            ],
            inputs=query
        )
        with gr.Row():
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.9,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )
        # Submit: append the turn, generate the reply, then try to render any ABC in it.
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history], queue=False).then(
            predict,
            [chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id],
            [chatbot, task_history],
            show_progress=True,
            queue=True
        ).then(process_and_render_abc, [chatbot, task_history, conversation_id], [chatbot])
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )
    return demo
# Load the tokenizer and the fp16 model onto the GPU in inference mode, then
# replace the model's generation defaults with the checkpoint's own config.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map='cuda',
).eval()
model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)

# Build the UI, enable request queueing, and start serving.
app = _launch_demo(model, tokenizer)
app.queue().launch()