support Qwen chat model
- app_modules/llm_loader.py +39 -10
- requirements.txt +1 -0
- server.py +3 -0
app_modules/llm_loader.py
CHANGED
@@ -207,6 +207,7 @@ class LLMLoader:
             0.01
             if "gpt4all-j" in MODEL_NAME_OR_PATH
             or "dolly" in MODEL_NAME_OR_PATH
+            or "Qwen" in MODEL_NAME_OR_PATH
             else 0
         )
         use_fast = (
@@ -216,11 +217,29 @@ class LLMLoader:
         )
         padding_side = "left"  # if "dolly" in MODEL_NAME_OR_PATH else None

-        config = AutoConfig.from_pretrained(
-            MODEL_NAME_OR_PATH,
-            trust_remote_code=True,
-            token=token,
+        config = (
+            AutoConfig.from_pretrained(
+                MODEL_NAME_OR_PATH,
+                trust_remote_code=True,
+                token=token,
+                fp32=hf_pipeline_device_type == "cpu",
+                bf16=(
+                    hf_pipeline_device_type != "cpu"
+                    and torch_dtype == torch.bfloat16
+                ),
+                fp16=(
+                    hf_pipeline_device_type != "cpu"
+                    and torch_dtype != torch.bfloat16
+                ),
+            )
+            if "Qwen" in MODEL_NAME_OR_PATH
+            else AutoConfig.from_pretrained(
+                MODEL_NAME_OR_PATH,
+                trust_remote_code=True,
+                token=token,
+            )
         )
+
         # config.attn_config["attn_impl"] = "triton"
         # config.max_seq_len = 4096
         config.init_device = hf_pipeline_device_type
@@ -360,16 +379,26 @@ class LLMLoader:
                     config=config,
                     trust_remote_code=True,
                 )
-                if token is None
-                else AutoModelForCausalLM.from_pretrained(
-                    MODEL_NAME_OR_PATH,
-                    config=config,
-                    trust_remote_code=True,
-                    token=token,
+                if "Qwen" in MODEL_NAME_OR_PATH
+                else (
+                    AutoModelForCausalLM.from_pretrained(
+                        MODEL_NAME_OR_PATH,
+                        config=config,
+                        trust_remote_code=True,
+                    )
+                    if token is None
+                    else AutoModelForCausalLM.from_pretrained(
+                        MODEL_NAME_OR_PATH,
+                        config=config,
+                        trust_remote_code=True,
+                        token=token,
+                    )
                 )
             )
         )
         print(f"Model memory footprint: {model.get_memory_footprint()}")
+        model = model.eval()
+        # print(f"Model memory footprint: {model.get_memory_footprint()}")
     else:
         model = MODEL_NAME_OR_PATH

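For reference, below is a minimal standalone sketch of the Qwen loading path this commit introduces, outside the LLMLoader class. The checkpoint name "Qwen/Qwen-7B-Chat" and the device/dtype values are assumptions for illustration, not part of the commit; the fp32/bf16/fp16 flags mirror the kwargs the diff passes to AutoConfig.from_pretrained, which Qwen's trust_remote_code config uses to pick the compute dtype.

# Minimal sketch of the Qwen loading path added in this commit (not the
# project's actual loader). MODEL_NAME_OR_PATH and device are assumptions.
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

MODEL_NAME_OR_PATH = "Qwen/Qwen-7B-Chat"  # assumed checkpoint
device = "cuda"            # stands in for hf_pipeline_device_type in the diff
torch_dtype = torch.bfloat16

# Qwen's remote config accepts fp32/bf16/fp16 flags, so the dtype is selected
# through the config rather than a torch_dtype argument.
config = AutoConfig.from_pretrained(
    MODEL_NAME_OR_PATH,
    trust_remote_code=True,
    fp32=device == "cpu",
    bf16=device != "cpu" and torch_dtype == torch.bfloat16,
    fp16=device != "cpu" and torch_dtype != torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME_OR_PATH,
    config=config,
    trust_remote_code=True,
).eval()  # the commit also switches the model to eval mode after loading
print(f"Model memory footprint: {model.get_memory_footprint()}")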
requirements.txt
CHANGED
@@ -32,3 +32,4 @@ gevent
 pydantic >= 1.10.11
 pypdf
 python-telegram-bot
+transformers_stream_generator
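The new transformers_stream_generator entry is presumably required by Qwen's trust_remote_code modeling code, which imports it for streamed generation.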
server.py
CHANGED
@@ -86,6 +86,9 @@ if __name__ == "__main__":
     chat_start = timer()
     chat_sync("What's generative AI?", chat_id="test_user")
     chat_sync("more on finance", chat_id="test_user")
+    # chat_sync("给我讲一个年轻人奋斗创业最终取得成功的故事。", chat_id="test_user")
+    # chat_sync("给这个故事起一个标题", chat_id="test_user")
+    # chat_sync("Write the game 'snake' in python", chat_id="test_user")
     chat_end = timer()
     total_time = chat_end - chat_start
     print(f"Total time used: {total_time:.3f} s")
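For reference, the two commented-out Chinese prompts translate roughly as "Tell me a story about a young person who strives to build a startup and ultimately succeeds." and "Give this story a title."; they are presumably manual checks of Qwen's Chinese chat ability.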