moyanwang commited on
Commit
5f232f5
·
1 Parent(s): 05581a1

update demo

Browse files
Files changed (1) hide show
  1. demo.py +10 -24
demo.py CHANGED
@@ -1,47 +1,33 @@
1
-
2
 
3
  from transformers import AutoTokenizer
4
  from faster_chat_glm import GLM6B, FasterChatGLM
5
 
6
 
7
- MAX_OUT_LEN = 50
8
- BATCH_SIZE = 8
9
- USE_CACHE = True
10
-
11
- print("Prepare config and inputs....")
12
  chatglm6b_dir = './models'
13
  tokenizer = AutoTokenizer.from_pretrained(chatglm6b_dir, trust_remote_code=True)
14
-
15
- input_str = ["音乐推荐应该考虑哪些因素?帮我写一篇不少于800字的方案。 ", ] * BATCH_SIZE
16
  inputs = tokenizer(input_str, return_tensors="pt", padding=True)
17
- input_ids = inputs.input_ids
18
- input_ids = input_ids.to('cuda:0')
19
- print(input_ids.shape)
20
 
21
 
22
- print('Loading faster model...')
23
- if USE_CACHE:
24
- plan_path = f'./models/glm6b-kv-cache-dy-bs{BATCH_SIZE}.ftm'
25
- else:
26
- plan_path = f'./models/glm6b-bs{BATCH_SIZE}.ftm'
27
-
28
  # kernel for chat model.
29
  kernel = GLM6B(plan_path=plan_path,
30
- batch_size=BATCH_SIZE,
31
  num_beams=1,
32
- use_cache=USE_CACHE,
33
  num_heads=32,
34
  emb_size_per_heads=128,
35
  decoder_layers=28,
36
  vocab_size=150528,
37
  max_seq_len=MAX_OUT_LEN)
38
- print("test")
39
- chat = FasterChatGLM(model_dir=chatglm6b_dir, kernel=kernel).half().cuda()
40
 
41
  # generate
42
  sample_output = chat.generate(inputs=input_ids, max_length=MAX_OUT_LEN)
43
  # de-tokenize model output to text
44
  res = tokenizer.decode(sample_output[0], skip_special_tokens=True)
45
- print(res)
46
- res = tokenizer.decode(sample_output[BATCH_SIZE-1], skip_special_tokens=True)
47
- print(res)
 
1
+ # coding=utf-8
2
 
3
  from transformers import AutoTokenizer
4
  from faster_chat_glm import GLM6B, FasterChatGLM
5
 
6
 
7
+ MAX_OUT_LEN = 100
 
 
 
 
8
  chatglm6b_dir = './models'
9
  tokenizer = AutoTokenizer.from_pretrained(chatglm6b_dir, trust_remote_code=True)
10
+ input_str = ["为什么我们需要对深度学习模型加速?", ]
 
11
  inputs = tokenizer(input_str, return_tensors="pt", padding=True)
12
+ input_ids = inputs.input_ids.to('cuda:0')
 
 
13
 
14
 
15
+ plan_path = './models/glm6b-bs8.ftm'
 
 
 
 
 
16
  # kernel for chat model.
17
  kernel = GLM6B(plan_path=plan_path,
18
+ batch_size=1,
19
  num_beams=1,
20
+ use_cache=True,
21
  num_heads=32,
22
  emb_size_per_heads=128,
23
  decoder_layers=28,
24
  vocab_size=150528,
25
  max_seq_len=MAX_OUT_LEN)
26
+
27
+ chat = FasterChatGLM(model_dir="./models", kernel=kernel).half().cuda()
28
 
29
  # generate
30
  sample_output = chat.generate(inputs=input_ids, max_length=MAX_OUT_LEN)
31
  # de-tokenize model output to text
32
  res = tokenizer.decode(sample_output[0], skip_special_tokens=True)
33
+ print(res)