File size: 12,393 Bytes
2158a6f
 
7bd9bb1
2158a6f
7bd9bb1
 
 
 
2158a6f
7bd9bb1
 
 
2158a6f
39d0b5a
eac8ee9
7bd9bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2158a6f
7bd9bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2158a6f
7bd9bb1
 
 
 
 
 
 
 
 
 
 
 
2158a6f
 
 
 
9644bee
 
 
 
 
 
7bd9bb1
9644bee
 
7bd9bb1
9644bee
 
 
 
2158a6f
9644bee
2158a6f
 
7bd9bb1
2158a6f
7bd9bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19d4c67
7bd9bb1
 
c8c4c6a
 
 
 
 
 
7bd9bb1
9644bee
7bd9bb1
 
9644bee
7bd9bb1
 
 
 
 
 
 
 
 
 
9644bee
7bd9bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc9a57c
7bd9bb1
 
 
 
fc9a57c
 
7bd9bb1
 
 
fc9a57c
7bd9bb1
 
 
2158a6f
7bd9bb1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
import os
import re
import copy
import time
import logging
import subprocess
from uuid import uuid4
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import spaces

# Let Qt applications start without a display server.
# NOTE(review): presumably this is for the `musescore` subprocess used by
# postprocess_abc to render scores/audio — confirm.
os.environ['QT_QPA_PLATFORM']='offscreen'

# Disable the memory-efficient and flash scaled-dot-product-attention kernels so
# SDPA falls back to the remaining backends.
# NOTE(review): the reason is not recorded here (fp16 numerics? unsupported
# GPU?) — confirm before re-enabling.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

# Output directories: `logs/` holds the server log, `tmp/` holds the
# per-conversation ABC/MIDI/SVG/MP3 files written by postprocess_abc.
os.makedirs("logs", exist_ok=True)
os.makedirs("tmp", exist_ok=True)
# One log file per server start, named after the UTC start time. Only WARNING
# and above are recorded, which is why conversation records are emitted at
# CRITICAL level.
# NOTE(review): the ':' characters in this filename are invalid on Windows.
logging.basicConfig(
    filename=f'logs/chatmusician_server_{time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))}.log',
    level=logging.WARNING,
    format='%(asctime)s [%(levelname)s]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Hugging Face Hub id used for the tokenizer, model and generation config.
MODEL_PATH = 'm-a-p/ChatMusician'


def get_uuid():
    """Return a fresh random (version 4) UUID in its canonical string form."""
    fresh_id = uuid4()
    return f"{fresh_id}"


def log_conversation(conversation_id, history, messages, response, generate_kwargs):
    """Write one finished chat turn to the server log.

    The record is a dict with the keys ``conversation_id``, ``timestamp``
    (UTC, ``YYYY-MM-DD HH:MM:SS``), ``history``, ``messages``, ``response`` and
    ``generate_kwargs``. It is rendered with ``str()`` and emitted at CRITICAL
    level, so it passes the WARNING threshold the server logger is configured with.

    Args:
        conversation_id: identifier of the chat session.
        history: the conversation history after this turn.
        messages: the exact prompt text sent to the model.
        response: the response as shown to the user.
        generate_kwargs: serialised generation settings used for this turn.
    """
    utc_stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())
    record = dict(
        conversation_id=conversation_id,
        timestamp=utc_stamp,
        history=history,
        messages=messages,
        response=response,
        generate_kwargs=generate_kwargs,
    )
    logging.critical(str(record))


def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f"<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def convert_history_to_text(task_history):
    """Serialise a chat history into the ChatMusician prompt format.

    Args:
        task_history: sequence of ``(user, assistant)`` pairs. The last pair is
            the turn being generated, so only its user text is used. Earlier
            pairs whose user text is empty/None are skipped.

    Returns:
        ``"Human: <user> </s> Assistant: <assistant> </s> "`` for every earlier
        turn, followed by ``"Human: <last user> </s> Assistant: "`` so the
        model continues with the assistant reply.

    Raises:
        IndexError: if ``task_history`` is empty.
    """
    # The history is only read, so no (deep) copy is needed. The original
    # deep-copied the whole conversation on every request.
    finished_turns = "".join(
        f"Human: {item[0]} </s> Assistant: {item[1]} </s> "
        for item in task_history[:-1]
        if item[0]
    )
    return finished_turns + f"Human: {task_history[-1][0]} </s> Assistant: "

def _run_converter(cmd):
    """Run one external conversion command; log (do not raise) a non-zero exit.

    Output is captured so that a failing tool's stderr reaches the server log
    instead of being lost. A missing executable still raises
    ``FileNotFoundError`` to the caller.
    """
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    if result.returncode != 0:
        logging.error(
            "%s exited with status %d: %s", cmd[0], result.returncode, result.stderr.strip()
        )
    return result


def postprocess_abc(text, conversation_id):
    """Extract the first ABC tune from ``text`` and render it to SVG and MP3.

    Files are written under ``tmp/<conversation_id>/`` and named after the
    current timestamp: ``.abc`` (the extracted tune), ``.mid`` (via
    ``abc2midi``), plus SVG and ``.mp3`` (via ``musescore`` from the MIDI).

    Args:
        text: model response that may contain ABC notation.
        conversation_id: session id used as the output sub-directory.

    Returns:
        ``(svg_path, audio_path)`` when an ABC block was found, else
        ``(None, None)``. The paths are returned even if a conversion step
        failed (the failure is logged), so callers must check that they exist.

    Raises:
        OSError: if a converter executable cannot be started.
    """
    out_dir = f"tmp/{conversation_id}"
    os.makedirs(out_dir, exist_ok=True)
    # An ABC tune starts with an "X:<number>" header line. The repeated-line
    # group is greedy, so the match runs to the end of the text. A newline is
    # appended so that the final line is included.
    abc_pattern = r'(X:\d+\n(?:[^\n]*\n)+)'
    abc_notation = re.findall(abc_pattern, text + '\n')
    print(f'extract abc block: {abc_notation}')
    if not abc_notation:
        return None, None

    stem = f"{out_dir}/{time.time()}"
    tmp_abc = f"{stem}.abc"
    # Model output may contain non-ASCII text (titles, composers); don't depend
    # on the platform's default encoding.
    with open(tmp_abc, "w", encoding="utf-8") as abc_file:
        abc_file.write(abc_notation[0])

    # ABC -> MIDI
    tmp_midi = f"{stem}.mid"
    _run_converter(["abc2midi", tmp_abc, "-o", tmp_midi])

    # MIDI -> score image and audio
    svg_file = f"{stem}.svg"
    audio_file = f"{stem}.mp3"
    _run_converter(["musescore", "-o", svg_file, tmp_midi])
    _run_converter(["musescore", "-o", audio_file, tmp_midi])
    # NOTE: presumably MuseScore suffixes each SVG page with "-<n>"; only the
    # first page is returned for display.
    return f"{stem}-1.svg", audio_file


def _launch_demo(model, tokenizer):
    """Build the ChatMusician Gradio UI around an already-loaded model.

    Args:
        model: causal LM used for generation (its ``.device`` and
            ``.generate`` are used).
        tokenizer: matching tokenizer; its ``eos_token_id`` serves as both the
            EOS and the pad token.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """

    @spaces.GPU
    def predict(_chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id):
        """Generate the assistant reply for the last user turn.

        ``_chatbot`` holds the HTML-rendered display history. ``task_history``
        holds the raw ``(query, response)`` pairs that are re-serialised into
        the next prompt, so it must never contain ``_parse_text`` output.
        """
        query = task_history[-1][0]
        print("User: " + _parse_text(query))
        messages = convert_history_to_text(task_history)
        inputs = tokenizer(messages, return_tensors="pt", add_special_tokens=False)
        temperature = float(temperature)
        # The slider allows 0.0, but transformers rejects temperature <= 0 when
        # sampling. Treat 0 as a request for greedy decoding instead.
        do_sample = temperature > 0
        generation_config = GenerationConfig(
            temperature=temperature if do_sample else 1.0,
            top_p=float(top_p),
            # Slider values may arrive as floats; the top-k warper requires an int.
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            max_new_tokens=1536,
            min_new_tokens=5,
            do_sample=do_sample,
            num_beams=1,
            num_return_sequences=1,
        )
        output_ids = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )
        # Drop the prompt tokens so that only the newly generated text is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        _chatbot[-1] = (_parse_text(query), _parse_text(response))
        # Keep the raw query: storing the HTML version corrupted later prompts
        # with <br> tags and HTML entities.
        task_history[-1] = (query, response)
        log_conversation(conversation_id, task_history, messages, _chatbot[-1][1], generation_config.to_json_string())
        return _chatbot, task_history

    def process_and_render_abc(_chatbot, task_history, conversation_id):
        """Append rendered audio and score of the latest reply's ABC tune, if any.

        This is best effort: any failure is logged and the chat is returned
        without media.
        """
        # Nothing to render if generation failed or produced no text.
        if not task_history or not task_history[-1][1]:
            return _chatbot
        svg_file, audio_file = None, None
        try:
            svg_file, audio_file = postprocess_abc(task_history[-1][1], conversation_id)
        except Exception:
            logging.exception("ABC post-processing failed")

        if svg_file and audio_file:
            if os.path.exists(svg_file) and os.path.exists(audio_file):
                logging.critical(f"generate: svg: {svg_file} wav: {audio_file}")
                print(f"generate:\n{svg_file}\n{audio_file}")
                _chatbot.append((None, (str(audio_file),)))
                _chatbot.append((None, (str(svg_file),)))
            else:
                logging.error(f"fail to convert: svg: {svg_file} audio: {audio_file}")
        return _chatbot

    def add_text(history, task_history, text):
        """Append the user's message to both histories and clear the input box."""
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(text, None)]
        return history, task_history, ""

    def reset_state(task_history):
        """Start a fresh conversation: empty the display and the model history."""
        task_history.clear()
        return [], []

    with gr.Blocks() as demo:
        conversation_id = gr.State(get_uuid)
        gr.Markdown(
            """<h1><center>Chat Musician</center></h1>"""
        )
        gr.Markdown("""\
        <center><font size=4><a href="https://ezmonyi.github.io/ChatMusician/">🌐 DemoPage</a>&nbsp |
        &nbsp<a href="https://github.com/hf-lin/ChatMusician">💻 Github</a>&nbsp |
        &nbsp<a href="http://arxiv.org/abs/2402.16153">📖 arXiv</a>&nbsp |
        &nbsp<a href="https://huggingface.co/datasets/m-a-p/MusicTheoryBench">🤗 Benchmark</a>&nbsp |
        &nbsp<a href="https://huggingface.co/datasets/m-a-p/MusicPile">🤗 Pretrain Dataset</a>&nbsp |
        &nbsp<a href="https://huggingface.co/datasets/m-a-p/MusicPile-sft">🤗 SFT Dataset</a></center>""")
        gr.Markdown("""\
    <center><font size=4>💡Note: The music clips on this page are auto-converted using musescore2 which may not be perfect, 
    and we recommend using better software for analysis.</center>""")

        chatbot = gr.Chatbot(label='ChatMusician', elem_classes="control-height", height=750)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit (发送)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
        gr.Examples(
            examples=[
                ["Create music by following the alphabetic representation of the assigned musical structure and the given motif.\n'ABCA';X:1\nL:1/16\nM:2/4\nK:A\n['E2GB d2c2 B2A2', 'D2 C2E2 A2c2']"],
                ["Develop a melody using the given chord pattern.\n'C', 'C', 'G/D', 'D', 'G', 'C', 'G', 'G', 'C', 'C', 'F', 'C/G', 'G7', 'C'"],
                ["Create sheet music in ABC notation from the provided text.\nAlternative title: \nThe Legacy\nKey: G\nMeter: 6/8\nNote Length: 1/8\nRhythm: Jig\nOrigin: English\nTranscription: John Chambers"],
            ],
            inputs=query,
        )
        with gr.Row():
            with gr.Accordion("Advanced Options:", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.2,
                                minimum=0.0,
                                maximum=10.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=0.9,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=40,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )

        # add_text also returns "" for the textbox, so `query` is wired as an
        # output and the input box is cleared by the same event.
        submit_btn.click(
            add_text, [chatbot, task_history, query], [chatbot, task_history, query], queue=False
        ).then(
            predict,
            [chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id],
            [chatbot, task_history],
            show_progress=True,
            queue=True,
        ).then(process_and_render_abc, [chatbot, task_history, conversation_id], [chatbot])
        # Return fresh values for both histories rather than relying on the
        # in-place clear being visible through the State component.
        empty_bin.click(reset_state, [task_history], [chatbot, task_history], show_progress=True)

        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    return demo
    

# Load the tokenizer and the model from the Hub at import time. The model is
# loaded in float16, placed on CUDA, and switched to inference mode.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map='cuda',
    torch_dtype=torch.float16
).eval()

# Use the generation defaults published with the model. Per-request sampling
# settings are still passed explicitly as a GenerationConfig by the UI's
# predict handler.
model.generation_config = GenerationConfig.from_pretrained(
    MODEL_PATH
)

app = _launch_demo(model, tokenizer)

# Enable Gradio's request queue (the generation step is registered with
# queue=True) and start serving.
app.queue().launch()