hf-lin commited on
Commit
7bd9bb1
β€’
1 Parent(s): 276fd1d

update app

Browse files
Files changed (3) hide show
  1. Dockerfile +15 -0
  2. app.py +275 -26
  3. requirements.txt +0 -1
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+ RUN apt update
9
+ RUN apt install fuse libfuse2
10
+
11
+ COPY . .
12
+ RUN chmod +x MuseScore-4.1.1.232071203-x86_64.AppImage
13
+ RUN mkdir -m 700 flagged
14
+
15
+ CMD ["uvicorn", "app:gradio_app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,50 +1,299 @@
1
- import gradio as gr
2
  import os
3
  import re
4
- import subprocess
5
  import time
6
- from symusic import Score, Synthesizer
7
- import torchaudio
 
 
8
  import torch
 
 
 
 
 
9
 
10
- # for rendering abc notation
11
  os.environ['QT_QPA_PLATFORM']='offscreen'
12
- os.system("apt-get install fuse libfuse2")
13
- os.system("chmod +x MuseScore-4.1.1.232071203-x86_64.AppImage")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- default_abc = 'X:1\nL:1/8\nM:2/4\nK:G\n|:"G" G>A Bc | dB dB |"C" ce ce |"D7" dB A2 |"G" G>A Bc | dB dB |"Am" cA"D7" FA |"G" AG G2 :: \n"Em" g2"D" f>e | de Bd |"C" ce ce |"D7" dB A2 |"G" g2"D" f>e | de Bd |"Am" cA"D7" FA |"G" AG G2 :| \n'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- def parse_abc_notation(text='', conversation_id='debug'):
18
- # os.makedirs(f"tmp/", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
19
  ts = time.time()
 
20
  abc_pattern = r'(X:\d+\n(?:[^\n]*\n)+)'
21
  abc_notation = re.findall(abc_pattern, text+'\n')
22
  print(f'extract abc block: {abc_notation}')
23
  if abc_notation:
24
- # Convert ABC to midi
25
  s = Score.from_abc(abc_notation[0])
26
- wav_file = f'{ts}.mp3'
27
  audio = Synthesizer().render(s, stereo=True)
28
- torchaudio.save(wav_file, torch.FloatTensor(audio), 44100)
 
29
 
30
  # Convert abc notation to SVG
31
- tmp_midi = f'{ts}.mid'
32
  s.dump_midi(tmp_midi)
33
- svg_file = f'{ts}.svg'
34
  subprocess.run(["./MuseScore-4.1.1.232071203-x86_64.AppImage", "-f", "-o", svg_file, tmp_midi])
35
- return None, wav_file
 
36
  else:
37
- return None, None
38
 
39
 
40
- if __name__ == "__main__":
41
 
42
- gradio_app = gr.Interface(
43
- parse_abc_notation,
44
- inputs=["text"],
45
- outputs=[gr.Image(label="svg"), gr.Audio(label="audio")],
46
- title="ABC notation parse",
47
- examples=[default_abc]
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- gradio_app.launch()
 
 
1
  import os
2
  import re
3
+ import copy
4
  import time
5
+ import logging
6
+ import subprocess
7
+ from uuid import uuid4
8
+ import gradio as gr
9
  import torch
10
+ import torchaudio
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from transformers.generation import GenerationConfig
13
+ from symusic import Score, Synthesizer
14
+ import spaces
15
 
 
16
  os.environ['QT_QPA_PLATFORM']='offscreen'
17
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
18
+ torch.backends.cuda.enable_flash_sdp(False)
19
+
20
+ # log_dir
21
+ os.makedirs("logs", exist_ok=True)
22
+ os.makedirs("tmp", exist_ok=True)
23
+ logging.basicConfig(
24
+ filename=f'logs/chatmusician_server_{time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))}.log',
25
+ level=logging.WARNING,
26
+ format='%(asctime)s [%(levelname)s]: %(message)s',
27
+ datefmt='%Y-%m-%d %H:%M:%S'
28
+ )
29
+
30
+ MODEL_PATH = 'm-a-p/ChatMusician'
31
+
32
+
33
+ def get_uuid():
34
+ return str(uuid4())
35
+
36
+
37
+ # todo
38
+ def log_conversation(conversation_id, history, messages, response, generate_kwargs):
39
+ timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(time.time()))
40
+ data = {
41
+ "conversation_id": conversation_id,
42
+ "timestamp": timestamp,
43
+ "history": history,
44
+ "messages": messages,
45
+ "response": response,
46
+ "generate_kwargs": generate_kwargs,
47
+ }
48
+ logging.critical(f"{data}")
49
+
50
 
51
+ def _parse_text(text):
52
+ lines = text.split("\n")
53
+ lines = [line for line in lines if line != ""]
54
+ count = 0
55
+ for i, line in enumerate(lines):
56
+ if "```" in line:
57
+ count += 1
58
+ items = line.split("`")
59
+ if count % 2 == 1:
60
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
61
+ else:
62
+ lines[i] = f"<br></code></pre>"
63
+ else:
64
+ if i > 0:
65
+ if count % 2 == 1:
66
+ line = line.replace("`", r"\`")
67
+ line = line.replace("<", "&lt;")
68
+ line = line.replace(">", "&gt;")
69
+ line = line.replace(" ", "&nbsp;")
70
+ line = line.replace("*", "&ast;")
71
+ line = line.replace("_", "&lowbar;")
72
+ line = line.replace("-", "&#45;")
73
+ line = line.replace(".", "&#46;")
74
+ line = line.replace("!", "&#33;")
75
+ line = line.replace("(", "&#40;")
76
+ line = line.replace(")", "&#41;")
77
+ line = line.replace("$", "&#36;")
78
+ lines[i] = "<br>" + line
79
+ text = "".join(lines)
80
+ return text
81
 
82
+
83
+ def convert_history_to_text(task_history):
84
+ history_cp = copy.deepcopy(task_history)
85
+ text = "".join(
86
+ [f"Human: {item[0]} </s> Assistant: {item[1]} </s> " for item in history_cp[:-1] if item[0]]
87
+ )
88
+ text += f"Human: {history_cp[-1][0]} </s> Assistant: "
89
+ return text
90
+
91
+ # todo
92
+ def postprocess_abc(text, conversation_id):
93
+ os.makedirs(f"tmp/{conversation_id}", exist_ok=True)
94
  ts = time.time()
95
+
96
  abc_pattern = r'(X:\d+\n(?:[^\n]*\n)+)'
97
  abc_notation = re.findall(abc_pattern, text+'\n')
98
  print(f'extract abc block: {abc_notation}')
99
  if abc_notation:
100
+ # render ABC as audio
101
  s = Score.from_abc(abc_notation[0])
 
102
  audio = Synthesizer().render(s, stereo=True)
103
+ audio_file = f'tmp/{conversation_id}/{ts}.mp3'
104
+ torchaudio.save(audio_file, torch.FloatTensor(audio), 44100)
105
 
106
  # Convert abc notation to SVG
107
+ tmp_midi = f'tmp/{conversation_id}/{ts}.mid'
108
  s.dump_midi(tmp_midi)
109
+ svg_file = f'tmp/{conversation_id}/{ts}.svg'
110
  subprocess.run(["./MuseScore-4.1.1.232071203-x86_64.AppImage", "-f", "-o", svg_file, tmp_midi])
111
+
112
+ return svg_file, audio_file
113
  else:
114
+ return None, None
115
 
116
 
117
+ def _launch_demo(model, tokenizer):
118
 
119
+ @spaces.GPU
120
+ def predict(_chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id):
121
+ query = task_history[-1][0]
122
+ print("User: " + _parse_text(query))
123
+ # model generation
124
+ messages = convert_history_to_text(task_history)
125
+ inputs = tokenizer(messages, return_tensors="pt", add_special_tokens=False)
126
+ generation_config = GenerationConfig(
127
+ temperature=float(temperature),
128
+ top_p = float(top_p),
129
+ top_k = top_k,
130
+ repetition_penalty = float(repetition_penalty),
131
+ max_new_tokens=1536,
132
+ min_new_tokens=5,
133
+ do_sample=True,
134
+ num_beams=1,
135
+ num_return_sequences=1
136
+ )
137
+ response = model.generate(
138
+ input_ids=inputs["input_ids"].to(model.device),
139
+ attention_mask=inputs['attention_mask'].to(model.device),
140
+ eos_token_id=tokenizer.eos_token_id,
141
+ pad_token_id=tokenizer.eos_token_id,
142
+ generation_config=generation_config,
143
+ )
144
+ response = tokenizer.decode(response[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
145
+ _chatbot[-1] = (_parse_text(query), _parse_text(response))
146
+ task_history[-1] = (_parse_text(query), response)
147
+ # log
148
+ log_conversation(conversation_id, task_history, messages, _chatbot[-1][1], generation_config.to_json_string())
149
+ return _chatbot, task_history
150
+
151
+ def process_and_render_abc(_chatbot, task_history, conversation_id):
152
+ svg_file, wav_file = None, None
153
+ try:
154
+ svg_file, wav_file = postprocess_abc(task_history[-1][1], conversation_id)
155
+ except Exception as e:
156
+ logging.error(e)
157
+
158
+ if svg_file and wav_file:
159
+ if os.path.exists(svg_file) and os.path.exists(wav_file):
160
+ logging.critical(f"generate: svg: {svg_file} wav: {wav_file}")
161
+ print(f"generate:\n{svg_file}\n{wav_file}")
162
+ _chatbot.append((None, (str(wav_file),)))
163
+ _chatbot.append((None, (str(svg_file),)))
164
+ else:
165
+ logging.error(f"fail to convert: {svg_file[:-4]}.musicxml")
166
+ return _chatbot
167
+
168
+ def add_text(history, task_history, text):
169
+ history = history + [(_parse_text(text), None)]
170
+ task_history = task_history + [(text, None)]
171
+ return history, task_history, ""
172
+
173
+ def reset_user_input():
174
+ return gr.update(value="")
175
+
176
+ def reset_state(task_history):
177
+ task_history.clear()
178
+ return []
179
+
180
+ with gr.Blocks() as demo:
181
+ conversation_id = gr.State(get_uuid)
182
+ gr.Markdown(
183
+ f"<h1><center>Chat Musician</center></h1>"
184
+ )
185
+
186
+ gr.Markdown("""\
187
+ <center><font size=4>Chat-Musician <a href="https://huggingface.co/m-a-p/ChatMusician-v1-sft-78k">πŸ€—</a>&nbsp |
188
+ &nbsp<a href="https://github.com/a43992899/Chat-Musician">Github</a></center>""")
189
+ gr.Markdown("""\
190
+ <center><font size=4>πŸ’‘Note: The music clips on this page is auto-converted from abc notations which may not be perfect,
191
+ and we recommend using better software for analysis.</center>""")
192
+
193
+ chatbot = gr.Chatbot(label='Chat-Musician', elem_classes="control-height", height=750)
194
+ query = gr.Textbox(lines=2, label='Input')
195
+ task_history = gr.State([])
196
+
197
+ with gr.Row():
198
+ submit_btn = gr.Button("πŸš€ Submit (发送)")
199
+ empty_bin = gr.Button("🧹 Clear History (ζΈ…ι™€εŽ†ε²)")
200
+ # regen_btn = gr.Button("πŸ€”οΈ Regenerate (重试)")
201
+ gr.Examples(
202
+ examples=[
203
+ ["Create music by following the alphabetic representation of the assigned musical structure and the given motif.\n'ABCA';X:1\nL:1/16\nM:2/4\nK:A\n['E2GB d2c2 B2A2', 'D2 C2E2 A2c2']"],
204
+ ["Create sheet music in ABC notation from the provided text.\nAlternative title: \nThe Legacy\nKey: G\nMeter: 6/8\nNote Length: 1/8\nRhythm: Jig\nOrigin: English\nTranscription: John Chambers"],
205
+ ["Develop a melody using the given chord pattern.\n'C', 'C', 'G/D', 'D', 'G', 'C', 'G', 'G', 'C', 'C', 'F', 'C/G', 'G7', 'C'"]
206
+ ],
207
+ inputs=query
208
+ )
209
+ with gr.Row():
210
+ with gr.Accordion("Advanced Options:", open=False):
211
+ with gr.Row():
212
+ with gr.Column():
213
+ with gr.Row():
214
+ temperature = gr.Slider(
215
+ label="Temperature",
216
+ value=0.2,
217
+ minimum=0.0,
218
+ maximum=10.0,
219
+ step=0.1,
220
+ interactive=True,
221
+ info="Higher values produce more diverse outputs",
222
+ )
223
+ with gr.Column():
224
+ with gr.Row():
225
+ top_p = gr.Slider(
226
+ label="Top-p (nucleus sampling)",
227
+ value=0.9,
228
+ minimum=0.0,
229
+ maximum=1,
230
+ step=0.01,
231
+ interactive=True,
232
+ info=(
233
+ "Sample from the smallest possible set of tokens whose cumulative probability "
234
+ "exceeds top_p. Set to 1 to disable and sample from all tokens."
235
+ ),
236
+ )
237
+ with gr.Column():
238
+ with gr.Row():
239
+ top_k = gr.Slider(
240
+ label="Top-k",
241
+ value=40,
242
+ minimum=0.0,
243
+ maximum=200,
244
+ step=1,
245
+ interactive=True,
246
+ info="Sample from a shortlist of top-k tokens β€” 0 to disable and sample from all tokens.",
247
+ )
248
+ with gr.Column():
249
+ with gr.Row():
250
+ repetition_penalty = gr.Slider(
251
+ label="Repetition Penalty",
252
+ value=1.1,
253
+ minimum=1.0,
254
+ maximum=2.0,
255
+ step=0.1,
256
+ interactive=True,
257
+ info="Penalize repetition β€” 1.0 to disable.",
258
+ )
259
+
260
+ submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history], queue=False).then(
261
+ predict,
262
+ [chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id],
263
+ [chatbot, task_history],
264
+ show_progress=True,
265
+ queue=True
266
+ ).then(process_and_render_abc, [chatbot, task_history, conversation_id], [chatbot])
267
+ submit_btn.click(reset_user_input, [], [query])
268
+ empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
269
+
270
+ gr.Markdown(
271
+ "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
272
+ "factually accurate information. The model was trained on various public datasets; while great efforts "
273
+ "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
274
+ "biased, or otherwise offensive outputs.",
275
+ elem_classes=["disclaimer"],
276
+ )
277
+
278
+ return demo
279
+
280
+
281
+ tokenizer = AutoTokenizer.from_pretrained(
282
+ MODEL_PATH, trust_remote_code=True, resume_download=True,
283
+ )
284
+
285
+ model = AutoModelForCausalLM.from_pretrained(
286
+ MODEL_PATH,
287
+ device_map='auto',
288
+ torch_dtype=torch.float,
289
+ trust_remote_code=True,
290
+ resume_download=True,
291
+ ).eval()
292
+
293
+ model.generation_config = GenerationConfig.from_pretrained(
294
+ MODEL_PATH, trust_remote_code=True, resume_download=True,
295
+ )
296
+
297
+ app = _launch_demo(model, tokenizer)
298
 
299
+ app.queue().launch()
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  gradio==4.19.2
2
- music21==9.1.0
3
  symusic==0.4.2
4
  torch==2.2.1
5
  torchaudio==2.2.1
 
1
  gradio==4.19.2
 
2
  symusic==0.4.2
3
  torch==2.2.1
4
  torchaudio==2.2.1