ffreemt committed on
Commit 8afcc54
1 Parent(s): d7fd091
Files changed (1)
  1. app.py +86 -2
app.py CHANGED
@@ -2,8 +2,10 @@
 # pylint: disable=invalid-name, missing-function-docstring, missing-class-docstring, redefined-outer-name, broad-except
 import os
 import time
+from dataclasses import asdict, dataclass

 import gradio as gr
+from ctransformers import AutoConfig, AutoModelForCausalLM

 # from mcli import predict
 from huggingface_hub import hf_hub_download
@@ -17,6 +19,7 @@ if os.environ.get("MOSAICML_API_KEY") is None:
     raise ValueError("git environment variable must be set")
 # """

+
 def predict(x, y, timeout):
     logger.debug(f"{x=}, {y=}, {timeout=}")

@@ -31,6 +34,47 @@ def download_mpt_quant(destination_folder: str, repo_id: str, model_filename: st
     )


+@dataclass
+class GenerationConfig:
+    temperature: float
+    top_k: int
+    top_p: float
+    repetition_penalty: float
+    max_new_tokens: int
+    seed: int
+    reset: bool
+    stream: bool
+    threads: int
+    stop: list[str]
+
+
+def format_prompt(system_prompt: str, user_prompt: str):
+    """format prompt based on: https://huggingface.co/spaces/mosaicml/mpt-30b-chat/blob/main/app.py"""
+
+    system_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+    user_prompt = f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
+    assistant_prompt = f"<|im_start|>assistant\n"
+
+    return f"{system_prompt}{user_prompt}{assistant_prompt}"
+
+
+def generate(
+    llm: AutoModelForCausalLM,
+    generation_config: GenerationConfig,
+    system_prompt: str,
+    user_prompt: str,
+):
+    """run model inference, will return a Generator if streaming is true"""
+
+    return llm(
+        format_prompt(
+            system_prompt,
+            user_prompt,
+        ),
+        **asdict(generation_config),
+    )
+
+
 class Chat:
     default_system_prompt = "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
     system_format = "<|im_start|>system\n{}<|im_end|>\n"
@@ -120,7 +164,21 @@ def call_inf_server(prompt):
         # remove spl tokens from prompt
         spl_tokens = ["<|im_start|>", "<|im_end|>"]
         clean_prompt = prompt.replace(spl_tokens[0], "").replace(spl_tokens[1], "")
-        return response[len(clean_prompt) :]  # remove the prompt
+
+        # return response[len(clean_prompt) :]  # remove the prompt
+        try:
+            user_prompt = prompt
+            generator = generate(llm, generation_config, system_prompt, user_prompt.strip())
+            print(assistant_prefix, end=" ", flush=True)
+            for word in generator:
+                print(word, end="", flush=True)
+            print("")
+            response = word
+        except Exception as exc:
+            logger.error(exc)
+            response = f"{exc=}"
+        return response
+
     except Exception as e:
         # assume it is our error
         # just wait and try one more time
@@ -142,10 +200,36 @@ _ = """full url: https://huggingface.co/TheBloke/mpt-30B-chat-GGML/blob/main/mpt
 repo_id = "TheBloke/mpt-30B-chat-GGML"
 model_filename = "mpt-30b-chat.ggmlv0.q4_1.bin"
 destination_folder = "models"
-download_mpt_quant(destination_folder, repo_id, model_filename)
+
+# download_mpt_quant(destination_folder, repo_id, model_filename)

 logger.info("done dl")

+config = AutoConfig.from_pretrained("mosaicml/mpt-30b-chat", context_length=8192)
+llm = AutoModelForCausalLM.from_pretrained(
+    os.path.abspath("models/mpt-30b-chat.ggmlv0.q4_1.bin"),
+    model_type="mpt",
+    config=config,
+)
+
+system_prompt = "A conversation between a user and an LLM-based AI assistant named Local Assistant. Local Assistant gives helpful and honest answers."
+
+generation_config = GenerationConfig(
+    temperature=0.2,
+    top_k=0,
+    top_p=0.9,
+    repetition_penalty=1.0,
+    max_new_tokens=512,  # adjust as needed
+    seed=42,
+    reset=False,  # reset history (cache)
+    stream=True,  # streaming per word/token
+    threads=int(os.cpu_count() / 2),  # adjust for your CPU
+    stop=["<|im_end|>", "|<"],
+)
+
+user_prefix = "[user]: "
+assistant_prefix = "[assistant]:"
+
 with gr.Blocks(
     theme=gr.themes.Soft(),
     css=".disclaimer {font-variant-caps: all-small-caps;}",