inflaton committed
Commit b5f591e
1 Parent(s): 2732905

support Qwen chat model

Files changed (3):
  1. app_modules/llm_loader.py +39 -10
  2. requirements.txt +1 -0
  3. server.py +3 -0
app_modules/llm_loader.py CHANGED
@@ -207,6 +207,7 @@ class LLMLoader:
                 0.01
                 if "gpt4all-j" in MODEL_NAME_OR_PATH
                 or "dolly" in MODEL_NAME_OR_PATH
+                or "Qwen" in MODEL_NAME_OR_PATH
                 else 0
             )
             use_fast = (
@@ -216,11 +217,29 @@ class LLMLoader:
             )
             padding_side = "left"  # if "dolly" in MODEL_NAME_OR_PATH else None

-            config = AutoConfig.from_pretrained(
-                MODEL_NAME_OR_PATH,
-                trust_remote_code=True,
-                token=token,
+            config = (
+                AutoConfig.from_pretrained(
+                    MODEL_NAME_OR_PATH,
+                    trust_remote_code=True,
+                    token=token,
+                    fp32=hf_pipeline_device_type == "cpu",
+                    bf16=(
+                        hf_pipeline_device_type != "cpu"
+                        and torch_dtype == torch.bfloat16
+                    ),
+                    fp16=(
+                        hf_pipeline_device_type != "cpu"
+                        and torch_dtype != torch.bfloat16
+                    ),
+                )
+                if "Qwen" in MODEL_NAME_OR_PATH
+                else AutoConfig.from_pretrained(
+                    MODEL_NAME_OR_PATH,
+                    trust_remote_code=True,
+                    token=token,
+                )
             )
+
             # config.attn_config["attn_impl"] = "triton"
             # config.max_seq_len = 4096
             config.init_device = hf_pipeline_device_type
@@ -360,16 +379,26 @@ class LLMLoader:
                         config=config,
                         trust_remote_code=True,
                     )
-                    if token is None
-                    else AutoModelForCausalLM.from_pretrained(
-                        MODEL_NAME_OR_PATH,
-                        config=config,
-                        trust_remote_code=True,
-                        token=token,
+                    if "Qwen" in MODEL_NAME_OR_PATH
+                    else (
+                        AutoModelForCausalLM.from_pretrained(
+                            MODEL_NAME_OR_PATH,
+                            config=config,
+                            trust_remote_code=True,
+                        )
+                        if token is None
+                        else AutoModelForCausalLM.from_pretrained(
+                            MODEL_NAME_OR_PATH,
+                            config=config,
+                            trust_remote_code=True,
+                            token=token,
+                        )
                     )
                 )
             )
             print(f"Model memory footprint: {model.get_memory_footprint()}")
+            model = model.eval()
+            # print(f"Model memory footprint: {model.get_memory_footprint()}")
         else:
             model = MODEL_NAME_OR_PATH
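For readers skimming the hunk above: Qwen's Hub-hosted modeling code reads its precision switches (fp32/bf16/fp16) from the config rather than from a torch_dtype argument, which is what the conditional AutoConfig call encodes. A minimal standalone sketch of that selection, reusing the loader's hf_pipeline_device_type and torch_dtype values (the helper name itself is hypothetical):

import torch
from transformers import AutoConfig

def build_qwen_config(model_name: str, device_type: str, dtype: torch.dtype):
    # Hypothetical helper mirroring the diff: Qwen's config accepts
    # fp32/bf16/fp16 flags, and exactly one of them should end up True.
    on_cpu = device_type == "cpu"
    return AutoConfig.from_pretrained(
        model_name,
        trust_remote_code=True,  # Qwen ships custom modeling code on the Hub
        fp32=on_cpu,             # CPU path: fall back to full precision
        bf16=not on_cpu and dtype == torch.bfloat16,
        fp16=not on_cpu and dtype != torch.bfloat16,
    )

The added model = model.eval() call matches Qwen's own usage notes: it puts the model in inference mode (disabling dropout) before generation.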
requirements.txt CHANGED
@@ -32,3 +32,4 @@ gevent
 pydantic >= 1.10.11
 pypdf
 python-telegram-bot
+transformers_stream_generator
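The new dependency is needed because Qwen's remote modeling code imports transformers_stream_generator for its streaming generation path, so loading the model fails without it. For reference, the package's documented entry point (the usage below is illustrative, not code from this repo):

from transformers_stream_generator import init_stream_support

init_stream_support()  # patches transformers' generate() to accept do_stream=True
# afterwards, model.generate(input_ids, do_stream=True) yields tokens one by one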
server.py CHANGED
@@ -86,6 +86,9 @@ if __name__ == "__main__":
     chat_start = timer()
     chat_sync("What's generative AI?", chat_id="test_user")
     chat_sync("more on finance", chat_id="test_user")
+    # chat_sync("给我讲一个年轻人奋斗创业最终取得成功的故事。", chat_id="test_user")
+    # chat_sync("给这个故事起一个标题", chat_id="test_user")
+    # chat_sync("Write the game 'snake' in python", chat_id="test_user")
     chat_end = timer()
     total_time = chat_end - chat_start
     print(f"Total time used: {total_time:.3f} s")
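The commented-out prompts exercise Qwen's bilingual multi-turn chat (the Chinese ones translate roughly as "Tell me a story about a young person who strives to build a startup and finally succeeds" and "Give this story a title"). Outside this repo's chat_sync wrapper, Qwen-Chat checkpoints expose that interface through a chat() method supplied by their remote code; a minimal sketch, assuming the Qwen/Qwen-7B-Chat checkpoint and Hub access:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", trust_remote_code=True
).eval()  # eval() mirrors the model.eval() call added in llm_loader.py

# chat() returns the reply plus the updated history for follow-up turns
response, history = model.chat(tokenizer, "What's generative AI?", history=None)
response, history = model.chat(tokenizer, "more on finance", history=history)
print(response)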