quincyqiang committed on
Commit
3ba55ac
1 Parent(s): d69b65e
Files changed (1)
  1. inference.py +49 -0
inference.py ADDED
@@ -0,0 +1,49 @@
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ """
+ Single-turn chat: each reply is generated without any memory of earlier turns.
+ """
+
+
+ def main():
+     model_name = 'golaxy/gogpt2-7b'
+
+     max_new_tokens = 1024
+     top_p = 0.9
+     temperature = 0.95
+     repetition_penalty = 1.0
+     device = 'cuda'
+     input_pattern = '<s>{}</s>'
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         low_cpu_mem_usage=True,
+         torch_dtype=torch.float16,
+         device_map='auto'
+     ).eval()  # device_map='auto' already places the weights; an extra .to(device) would conflict with it
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         # llama does not support the fast tokenizer
+         use_fast=False if model.config.model_type == 'llama' else True
+     )
+     text = input('User:')
+     while True:
+         text = text.strip()
+         text = input_pattern.format(text)
+         input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
+         with torch.no_grad():
+             outputs = model.generate(
+                 input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,
+                 top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,
+                 eos_token_id=tokenizer.eos_token_id
+             )
+         outputs = outputs.tolist()[0][len(input_ids[0]):]  # keep only the newly generated tokens
+         response = tokenizer.decode(outputs)
+         response = response.strip().replace(text, "").replace('</s>', "").replace('<s>', "").strip()
+         print("Firefly:{}".format(response))
+         text = input('User:')
+
+
+ if __name__ == '__main__':
+     main()
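
The committed script is deliberately single-turn: as its docstring notes, each reply is generated with no memory of earlier exchanges. For contrast, here is a minimal multi-turn sketch built on the same loading and generation calls. The history format (concatenated '<s>turn</s>' segments fed back into the prompt) is an assumption for illustration only; this commit does not specify how gogpt2-7b expects chat history to be encoded, and nothing below guards against overflowing the context window.

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def chat():
    # Same loading scheme as the committed inference.py.
    model_name = 'golaxy/gogpt2-7b'
    device = 'cuda'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map='auto'
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_fast=False if model.config.model_type == 'llama' else True
    )
    # Accumulated '<s>user</s>reply</s>' turns -- an assumed format, not one
    # documented by this commit.
    history = ''
    while True:
        text = input('User:').strip()
        prompt = history + '<s>{}</s>'.format(text)
        input_ids = tokenizer(prompt, return_tensors='pt',
                              add_special_tokens=False).input_ids.to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids, max_new_tokens=1024, do_sample=True,
                top_p=0.9, temperature=0.95, repetition_penalty=1.0,
                eos_token_id=tokenizer.eos_token_id
            )
        # Keep only the newly generated tokens, as the original script does.
        new_tokens = outputs.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(new_tokens).replace('</s>', '').replace('<s>', '').strip()
        # Feed this exchange back in on the next turn.
        history = prompt + response + '</s>'
        print('Firefly:{}'.format(response))


if __name__ == '__main__':
    chat()

Keeping the history as raw text and re-encoding the whole prompt each turn is the simplest scheme; reusing past_key_values across turns would avoid the repeated encoding but complicates the sketch.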