silk-road commited on
Commit
730647b
·
verified ·
1 Parent(s): 10b3b99

Upload response_qwen_base.py

Browse files
Files changed (1) hide show
  1. ChatHaruhi/response_qwen_base.py +95 -0
ChatHaruhi/response_qwen_base.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ import warnings
3
+ warnings.filterwarnings("ignore")
4
+ import os
5
+ import re
6
+ import json
7
+ import torch
8
+ import pickle
9
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
10
+
11
+ client = None
12
+
13
+ def get_prompt(message):
14
+ #prompt = system_info.format(role_name=role_name, persona=persona)
15
+ persona = ""
16
+ for msg in message:
17
+ if msg["role"] == "system":
18
+ persona = persona + msg["content"]
19
+ prompt = "<<SYS>>" + persona + "<</SYS>>"
20
+ from ChatHaruhi.utils import normalize2uaua
21
+ message_ua = normalize2uaua(message[1:], if_replace_system = True)
22
+
23
+ for i in range(0, len(message_ua)-1, 2):
24
+ prompt = prompt + "[INST]" + message_ua[i]["content"] + "[/INST]" + message_ua[i+1]["content"] + "<|im_end|>"
25
+ prompt = prompt + "[INST]" + message_ua[-1]["content"] + "[/INST]"
26
+ print(prompt)
27
+ return prompt
28
+
29
+ import os
30
+ class qwen_model:
31
+ def __init__(self, model_name):
32
+ self.DEVICE = torch.device("cuda")
33
+ self.tokenizer = AutoTokenizer.from_pretrained(
34
+ "silk-road/"+model_name,
35
+ low_cpu_mem_usage=True,
36
+ use_fast = False,
37
+ padding_side="left",
38
+ trust_remote_code=True
39
+ )
40
+
41
+ if self.tokenizer.pad_token is None:
42
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
43
+ self.tokenizer.eos_token_id = 151645
44
+ # print(tokenizer.eos_token_id)
45
+ # print(tokenizer.pad_token_id)
46
+ self.model = AutoModelForCausalLM.from_pretrained(
47
+ "silk-road/"+model_name,
48
+ load_in_8bit=False,
49
+ torch_dtype=torch.bfloat16,
50
+ low_cpu_mem_usage=True,
51
+ device_map='auto',
52
+ trust_remote_code=True,
53
+ ).eval()
54
+ # model.to("cuda")
55
+ # model.eval()
56
+ # self.tokenizer = AutoTokenizer.from_pretrained("silk-road/"+model_name, trust_remote_code=True)
57
+ # self.model = AutoModelForCausalLM.from_pretrained("silk-road/"+model_name, device_map="auto", trust_remote_code=True).eval()
58
+
59
+ def get_response(self, message):
60
+ with torch.inference_mode():
61
+ prompt = get_prompt(message)
62
+ batch = self.tokenizer(prompt, return_tensors="pt", padding=True)
63
+ batch = self.tokenizer(prompt,
64
+ return_tensors="pt",
65
+ padding=True,
66
+ add_special_tokens=False)
67
+ batch = {k: v.to(self.DEVICE) for k, v in batch.items()}
68
+ generated = self.model.generate(input_ids=batch["input_ids"],
69
+ max_new_tokens=1024,
70
+ temperature=0.2,
71
+ top_p=0.9,
72
+ top_k=40,
73
+ do_sample=False,
74
+ num_beams=1,
75
+ repetition_penalty=1.3,
76
+ eos_token_id=self.tokenizer.eos_token_id,
77
+ pad_token_id=self.tokenizer.pad_token_id)
78
+ response = self.tokenizer.decode(generated[0][batch["input_ids"].shape[1]:]).strip().replace("<|im_end|>", "")
79
+ return response
80
+
81
+
82
+ def init_client(model_name):
83
+
84
+ # 将client设置为全局变量
85
+ global client
86
+
87
+ client = qwen_model(model_name = model_name)
88
+
89
+ def get_response(message, model_name = "Haruhi-Zero-1_8B-0_4"):
90
+ if client is None:
91
+ init_client(model_name)
92
+
93
+ response = client.get_response(message)
94
+ return response
95
+