File size: 4,222 Bytes
6dc0c9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""
Inference code for ChatGLM.
Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py.
"""
import re
import torch
from transformers.generation.logits_process import LogitsProcessor
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Recover from numerically broken logits during generation.

    When any entry of ``scores`` is NaN or +/-inf, every score is reset to
    zero and index 5 of the last dimension is set to 5e4, so that entry
    dominates the next decoding step. Fully finite scores pass through
    untouched. ``scores`` is modified in place and also returned.
    """

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # isfinite is False exactly for NaN and +/-inf, so "all finite" is the
        # negation of "any NaN or any inf".
        if bool(torch.isfinite(scores).all()):
            return scores
        # NOTE(review): token id 5 is presumably a safe fallback token in the
        # ChatGLM vocabulary — confirm against the tokenizer.
        scores.zero_()
        scores[..., 5] = 5e4
        return scores


# Single shared instance; the class keeps no per-call state, and it is passed
# in the ``logits_processor`` list of the generation kwargs.
invalid_score_processor = InvalidScoreLogitsProcessor()
def process_response(response):
    """Normalize a raw ChatGLM completion for display.

    Strips surrounding whitespace and replaces the "training time" placeholder
    token with a fixed year string. It then converts ASCII punctuation
    (, ! : ; ?) that directly precedes or follows a CJK ideograph
    (U+4E00..U+9FFF) into the corresponding full-width form
    (U+FF0C, U+FF01, U+FF1A, U+FF1B, U+FF1F). Punctuation between non-CJK
    characters is left unchanged.

    Args:
        response: Decoded model output text.

    Returns:
        The normalized text.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # (half-width, full-width) pairs. The full-width characters are written as
    # escapes on purpose: an NFKC/encoding round trip silently maps them back
    # to ASCII, which turns every pair into a no-op.
    punkts = [
        (",", "\uff0c"),  # FULLWIDTH COMMA
        ("!", "\uff01"),  # FULLWIDTH EXCLAMATION MARK
        (":", "\uff1a"),  # FULLWIDTH COLON
        (";", "\uff1b"),  # FULLWIDTH SEMICOLON
        ("?", "\uff1f"),  # FULLWIDTH QUESTION MARK
    ]
    cjk_group = r"([\u4e00-\u9fff])"
    for half, full in punkts:
        # re.escape handles regex metacharacters such as "?".
        half_pattern = re.escape(half)
        # Punctuation right after a CJK character ...
        response = re.sub(cjk_group + half_pattern, r"\1" + full, response)
        # ... and right before one.
        response = re.sub(half_pattern + cjk_group, full + r"\1", response)
    return response
def recover_message_list(prompt):
    """Split a ChatGLM3-style flattened prompt back into chat messages.

    The prompt is a sequence of role tokens (``<|system|>``, ``<|user|>``,
    ``<|assistant|>``), each followed by that turn's text. Every role token
    that is followed by another role token yields one
    ``{"role": ..., "content": ...}`` dict. A single newline directly after a
    role token is treated as a separator and dropped.

    Two kinds of text are not emitted:
    - Text before the first role token.
    - Text after the last role token. The prompt is expected to end with an
      open turn such as a trailing ``<|assistant|>``. Callers treat the last
      emitted message as the current user query.

    Args:
        prompt: The flattened prompt string.

    Returns:
        A list of message dicts in prompt order (possibly empty).
    """
    role_names = {
        "<|system|>": "system",
        "<|user|>": "user",
        "<|assistant|>": "assistant",
    }
    role_token_pattern = "|".join(re.escape(token) for token in role_names)

    message_list = []
    role = None
    content_start = 0
    for match in re.finditer(role_token_pattern, prompt):
        if role is not None:
            content = prompt[content_start : match.start()]
            # Drop only an actual newline separator; the previous
            # unconditional "+1" discarded a real character when it was absent.
            if content.startswith("\n"):
                content = content[1:]
            message_list.append({"role": role_names[role], "content": content})
        role = match.group(0)
        content_start = match.end()
    return message_list
def _chatglm_usage(prompt_tokens, total_tokens):
    """Return the token-usage dict reported with each streamed chunk."""
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": total_tokens - prompt_tokens,
        "total_tokens": total_tokens,
    }


@torch.inference_mode()
def generate_stream_chatglm(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream incremental completions from a ChatGLM model.

    Args:
        model: A ChatGLM model, optionally PEFT-wrapped, exposing
            ``stream_generate`` and ``device``.
        tokenizer: The matching tokenizer. For ChatGLM3 models it must
            provide ``build_chat_input``.
        params: Request dict. Reads ``prompt`` (required) and these optional
            keys:
            - ``temperature`` (default 1.0)
            - ``repetition_penalty`` (default 1.0)
            - ``top_p`` (default 1.0)
            - ``max_new_tokens`` (default 256)
            - ``echo`` (default True; include the prompt tokens in the text)
        device, context_len, stream_interval, judge_sent_end: Not used by
            this function. Inputs are moved to ``model.device`` instead.

    Yields:
        Dicts with three keys:
        - ``text``: the cumulative, post-processed output.
        - ``usage``: prompt/completion/total token counts.
        - ``finish_reason``: ``None`` on intermediate chunks. One final chunk
          follows with ``"length"`` if the new-token budget was used up,
          otherwise ``"stop"``.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)

    # A PEFT wrapper hides the real model class; look through it so the
    # ChatGLM3 check below sees the underlying model's type name.
    model_type = str(type(model)).lower()
    if "peft" in model_type:
        model_type = str(type(model.base_model.model)).lower()

    if "chatglm3" in model_type:
        # ChatGLM3 re-encodes structured messages with its own chat template,
        # so rebuild them from the flattened prompt.
        message_list = recover_message_list(prompt)
        inputs = tokenizer.build_chat_input(
            query=message_list[-1]["content"], history=message_list[:-1], role="user"
        ).to(model.device)
    else:
        inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    # Near-zero temperature means greedy decoding; temperature is only
    # passed when sampling.
    do_sample = temperature > 1e-5
    gen_kwargs = {
        "max_length": max_new_tokens + input_echo_len,
        "do_sample": do_sample,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [invalid_score_processor],
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    # Defaults for the case where stream_generate yields nothing, so the
    # final chunk never references an unbound name or reports a negative
    # completion count.
    response = ""
    total_len = input_echo_len
    for total_ids in model.stream_generate(**inputs, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        output_ids = total_ids if echo else total_ids[input_echo_len:]
        response = process_response(tokenizer.decode(output_ids))
        yield {
            "text": response,
            "usage": _chatglm_usage(input_echo_len, total_len),
            "finish_reason": None,
        }

    # Only the final chunk carries a finish_reason. Generation is capped via
    # max_length above, so consuming the whole new-token budget means the
    # output was cut off rather than ending on its own.
    if total_len - input_echo_len >= max_new_tokens:
        finish_reason = "length"
    else:
        finish_reason = "stop"
    yield {
        "text": response,
        "usage": _chatglm_usage(input_echo_len, total_len),
        "finish_reason": finish_reason,
    }
|