ffreemt committed on
Commit
b836071
1 Parent(s): 5a1e312
Files changed (2)
  1. app-.py +353 -0
  2. app.py +1 -1
app-.py ADDED
@@ -0,0 +1,353 @@
+"""Refer to https://github.com/abacaj/mpt-30B-inference/blob/main/download_model.py."""
+# pylint: disable=invalid-name, missing-function-docstring, missing-class-docstring, redefined-outer-name, broad-except
+import os
+import time
+from dataclasses import asdict, dataclass
+
+import gradio as gr
+from ctransformers import AutoConfig, AutoModelForCausalLM
+
+# from mcli import predict
+from huggingface_hub import hf_hub_download
+from loguru import logger
+
+URL = os.environ.get("URL")
+_ = """
+if URL is None:
+    raise ValueError("URL environment variable must be set")
+if os.environ.get("MOSAICML_API_KEY") is None:
+    raise ValueError("git environment variable must be set")
+# """
+
+
+def predict0(prompt, bot, timeout):
+    logger.debug(f"{prompt=}, {bot=}, {timeout=}")
+    try:
+        user_prompt = prompt
+        generator = generate(llm, generation_config, system_prompt, user_prompt.strip())
+        print(assistant_prefix, end=" ", flush=True)
+        for word in generator:
+            print(word, end="", flush=True)
+        print("")
+        response = word
+    except Exception as exc:
+        logger.error(exc)
+        response = f"{exc=}"
+    bot = {"inputs": [response]}
+
+    return prompt, bot
+
+
+def download_mpt_quant(destination_folder: str, repo_id: str, model_filename: str):
+    local_path = os.path.abspath(destination_folder)
+    return hf_hub_download(
+        repo_id=repo_id,
+        filename=model_filename,
+        local_dir=local_path,
+        local_dir_use_symlinks=True,
+    )
+
+
+@dataclass
+class GenerationConfig:
+    temperature: float
+    top_k: int
+    top_p: float
+    repetition_penalty: float
+    max_new_tokens: int
+    seed: int
+    reset: bool
+    stream: bool
+    threads: int
+    stop: list[str]
+
+
+def format_prompt(system_prompt: str, user_prompt: str):
+    """format prompt based on: https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py"""
+
+    system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+    user_prompt = f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
+    assistant_prompt = f"<|im_start|>assistant\n"
+
+    return f"{system_prompt}{user_prompt}{assistant_prompt}"
+
+
+def generate(
+    llm: AutoModelForCausalLM,
+    generation_config: GenerationConfig,
+    system_prompt: str,
+    user_prompt: str,
+):
+    """run model inference, will return a Generator if streaming is true"""
+
+    return llm(
+        format_prompt(
+            system_prompt,
+            user_prompt,
+        ),
+        **asdict(generation_config),
+    )
+
+
+class Chat:
+    default_system_prompt = "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
+    system_format = "<|im_start|>system\n{}<|im_end|>\n"
+
+    def __init__(
+        self, system: str = None, user: str = None, assistant: str = None
+    ) -> None:
+        if system is not None:
+            self.set_system_prompt(system)
+        else:
+            self.reset_system_prompt()
+        self.user = user if user else "<|im_start|>user\n{}<|im_end|>\n"
+        self.assistant = (
+            assistant if assistant else "<|im_start|>assistant\n{}<|im_end|>\n"
+        )
+        self.response_prefix = self.assistant.split("{}", maxsplit=1)[0]
+
+    def set_system_prompt(self, system_prompt):
+        # self.system = self.system_format.format(system_prompt)
+        return system_prompt
+
+    def reset_system_prompt(self):
+        return self.set_system_prompt(self.default_system_prompt)
+
+    def history_as_formatted_str(self, system, history) -> str:
+        system = self.system_format.format(system)
+        text = system + "".join(
+            [
+                "\n".join(
+                    [
+                        self.user.format(item[0]),
+                        self.assistant.format(item[1]),
+                    ]
+                )
+                for item in history[:-1]
+            ]
+        )
+        text += self.user.format(history[-1][0])
+        text += self.response_prefix
+        # stopgap solution to too long sequences
+        if len(text) > 4500:
+            # delete from the middle between <|im_start|> and <|im_end|>
+            # find the middle ones, then expand out
+            start = text.find("<|im_start|>", 139)
+            end = text.find("<|im_end|>", 139)
+            while end < len(text) and len(text) > 4500:
+                end = text.find("<|im_end|>", end + 1)
+                text = text[:start] + text[end + 1 :]
+            if len(text) > 4500:
+                # the nice way didn't work, just truncate
+                # deleting the beginning
+                text = text[-4500:]
+
+        return text
+
+    def clear_history(self, history):
+        return []
+
+    def turn(self, user_input: str):
+        self.user_turn(user_input)
+        return self.bot_turn()
+
+    def user_turn(self, user_input: str, history):
+        history.append([user_input, ""])
+        return user_input, history
+
+    def bot_turn(self, system, history):
+        conversation = self.history_as_formatted_str(system, history)
+        assistant_response = call_inf_server(conversation)
+        history[-1][-1] = assistant_response
+        print(system)
+        print(history)
+        return "", history
+
+
+def call_inf_server(prompt):
+    try:
+        response = predict(
+            URL,
+            {"inputs": [prompt], "temperature": 0.2, "top_p": 0.9, "output_len": 512},
+            timeout=70,
+        )
+        # print(f'prompt: {prompt}')
+        # print(f'len(prompt): {len(prompt)}')
+        response = response["outputs"][0]
+        # print(f'len(response): {len(response)}')
+        # remove spl tokens from prompt
+        spl_tokens = ["<|im_start|>", "<|im_end|>"]
+        clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
+
+        # return response[len(clean_prompt) :]  # remove the prompt
+        try:
+            user_prompt = prompt
+            generator = generate(llm, generation_config, system_prompt, user_prompt.strip())
+            print(assistant_prefix, end=" ", flush=True)
+            for word in generator:
+                print(word, end="", flush=True)
+            print("")
+            response = word
+        except Exception as exc:
+            logger.error(exc)
+            response = f"{exc=}"
+        return response
+
+    except Exception as e:
+        # assume it is our error
+        # just wait and try one more time
+        print(e)
+        time.sleep(1)
+        response = predict(
+            URL,
+            {"inputs": [prompt], "temperature": 0.2, "top_p": 0.9, "output_len": 512},
+            timeout=70,
+        )
+        # print(response)
+        response = response["outputs"][0]
+        return response[len(prompt) :]  # remove the prompt
+
+
+logger.info("start dl")
+_ = """full url: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/blob/main/mpt-30b-chat.ggmlv0.q4_1.bin"""
+
+repo_id = "TheBloke/mpt-30B-chat-GGML"
+model_filename = "mpt-30b-chat.ggmlv0.q4_1.bin"
+destination_folder = "models"
+
+download_mpt_quant(destination_folder, repo_id, model_filename)
+
+logger.info("done dl")
+
+config = AutoConfig.from_pretrained("mosaicml/mpt-30b-chat", context_length=8192)
+llm = AutoModelForCausalLM.from_pretrained(
+    os.path.abspath("models/mpt-30b-chat.ggmlv0.q4_1.bin"),
+    model_type="mpt",
+    config=config,
+)
+
+system_prompt = "A conversation between a user and an LLM-based AI assistant named Local Assistant. Local Assistant gives helpful and honest answers."
+
+generation_config = GenerationConfig(
+    temperature=0.2,
+    top_k=0,
+    top_p=0.9,
+    repetition_penalty=1.0,
+    max_new_tokens=512,  # adjust as needed
+    seed=42,
+    reset=False,  # reset history (cache)
+    stream=True,  # streaming per word/token
+    threads=int(os.cpu_count() / 2),  # adjust for your CPU
+    stop=["<|im_end|>", "|<"],
+)
+
+user_prefix = "[user]: "
+assistant_prefix = "[assistant]:"
+
+with gr.Blocks(
+    theme=gr.themes.Soft(),
+    css=".disclaimer {font-variant-caps: all-small-caps;}",
+) as demo:
+    gr.Markdown(
+        """<h1><center>MosaicML MPT-30B-Chat</center></h1>
+
+This demo is of [MPT-30B-Chat](https://huggingface.co/mosaicml/mpt-30b-chat). It is based on [MPT-30B](https://huggingface.co/mosaicml/mpt-30b) fine-tuned on approximately 300,000 turns of high-quality conversations, and is powered by [MosaicML Inference](https://www.mosaicml.com/inference).
+
+If you're interested in [training](https://www.mosaicml.com/training) and [deploying](https://www.mosaicml.com/inference) your own MPT or LLMs, [sign up](https://forms.mosaicml.com/demo?utm_source=huggingface&utm_medium=referral&utm_campaign=mpt-30b) for MosaicML platform.
+
+"""
+    )
+    conversation = Chat()
+    chatbot = gr.Chatbot().style(height=500)
+    with gr.Row():
+        with gr.Column():
+            msg = gr.Textbox(
+                label="Chat Message Box",
+                placeholder="Chat Message Box",
+                show_label=False,
+            ).style(container=False)
+        with gr.Column():
+            with gr.Row():
+                submit = gr.Button("Submit")
+                stop = gr.Button("Stop")
+                clear = gr.Button("Clear")
+    with gr.Row():
+        with gr.Accordion("Advanced Options:", open=False):
+            with gr.Row():
+                with gr.Column(scale=2):
+                    system = gr.Textbox(
+                        label="System Prompt",
+                        value=Chat.default_system_prompt,
+                        show_label=False,
+                    ).style(container=False)
+                with gr.Column():
+                    with gr.Row():
+                        change = gr.Button("Change System Prompt")
+                        reset = gr.Button("Reset System Prompt")
+    with gr.Row():
+        gr.Markdown(
+            "Disclaimer: MPT-30B can produce factually incorrect output, and should not be relied on to produce "
+            "factually accurate information. MPT-30B was trained on various public datasets; while great efforts "
+            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+            "biased, or otherwise offensive outputs.",
+            elem_classes=["disclaimer"],
+        )
+    with gr.Row():
+        gr.Markdown(
+            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
+            elem_classes=["disclaimer"],
+        )
+
+    _ = """
+    submit_event = msg.submit(
+        fn=conversation.user_turn,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).then(
+        fn=conversation.bot_turn,
+        inputs=[system, chatbot],
+        outputs=[msg, chatbot],
+        queue=True,
+    )
+    submit_click_event = submit.click(
+        fn=conversation.user_turn,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).then(
+        # fn=conversation.bot_turn,
+        inputs=[system, chatbot],
+        outputs=[msg, chatbot],
+        queue=True,
+    )
+
+    stop.click(
+        fn=None,
+        inputs=None,
+        outputs=None,
+        cancels=[submit_event, submit_click_event],
+        queue=False,
+    )
+    clear.click(lambda: None, None, chatbot, queue=False).then(
+        fn=conversation.clear_history,
+        inputs=[chatbot],
+        outputs=[chatbot],
+        queue=False,
+    )
+    change.click(
+        fn=conversation.set_system_prompt,
+        inputs=[system],
+        outputs=[system],
+        queue=False,
+    )
+    reset.click(
+        fn=conversation.reset_system_prompt,
+        inputs=[],
+        outputs=[system],
+        queue=False,
+    )
+    # """
+
+
+demo.queue(max_size=36, concurrency_count=14).launch(debug=True)
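
A minimal standalone sketch of how the pieces in app-.py fit together: download the quantized checkpoint, load it with ctransformers, and stream a reply. The repo id, filename, and loader calls mirror the file above; the prompt text and generation settings here are only illustrative.

import os

from ctransformers import AutoConfig, AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# same checkpoint app-.py downloads; hf_hub_download returns the local file path
model_path = hf_hub_download(
    repo_id="TheBloke/mpt-30B-chat-GGML",
    filename="mpt-30b-chat.ggmlv0.q4_1.bin",
    local_dir=os.path.abspath("models"),
)

config = AutoConfig.from_pretrained("mosaicml/mpt-30b-chat", context_length=8192)
llm = AutoModelForCausalLM.from_pretrained(model_path, model_type="mpt", config=config)

# ChatML prompt in the same layout as format_prompt(); the wording is made up
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nSay hello in one sentence.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# with stream=True the call yields text fragments; accumulate them rather than
# keeping only the last fragment
pieces = []
for piece in llm(prompt, max_new_tokens=64, stream=True, stop=["<|im_end|>"]):
    pieces.append(piece)
print("".join(pieces))
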
app.py CHANGED
@@ -321,7 +321,6 @@ with gr.Blocks(
         outputs=[msg, chatbot],
         queue=True,
     )
-    # """
 
     stop.click(
         fn=None,
@@ -348,6 +347,7 @@ with gr.Blocks(
         outputs=[system],
         queue=False,
     )
+    # """
 
 
 demo.queue(max_size=36, concurrency_count=14).launch(debug=True)