# demo_test / fastchat /model /model_chatglm.py
# yuantao-infini-ai's picture
# Upload 136 files
# 7472549 verified
# raw
# history blame
# 3.07 kB
"""
Inference code for ChatGLM.
Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py.
"""
import re
import torch
from transformers.generation.logits_process import LogitsProcessor
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Guard generation against non-finite logits.

    When any entry of ``scores`` is NaN or +/-inf, the whole tensor is reset
    to zero in place and index 5 of the last dimension is set to 5e4, so token
    id 5 overwhelmingly dominates the next-token choice. Fully finite scores
    pass through untouched.
    """

    # NOTE(review): token id 5 is presumably a special token of the ChatGLM
    # vocabulary picked as a safe fallback -- confirm against the tokenizer.

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        all_finite = torch.isfinite(scores).all()
        if not all_finite:
            # Mutated in place; the same tensor object is returned below.
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


# One shared, stateless instance; it is placed in the `logits_processor` list
# of `generate_stream_chatglm`'s generation kwargs.
invalid_score_processor = InvalidScoreLogitsProcessor()
def process_response(response):
    """Normalise a decoded ChatGLM response for display.

    Steps:
      1. Strip surrounding whitespace.
      2. Replace the ``[[训练时间]]`` ("training time") placeholder with
         ``2023年``.
      3. Convert ASCII ``,`` ``!`` ``:`` ``;`` ``?`` into their full-width
         forms (``，！：；？``), but only where the mark directly follows or
         precedes a CJK Unified Ideograph (U+4E00..U+9FFF). Punctuation
         between non-CJK characters is left as-is.

    Args:
        response: Raw decoded text.

    Returns:
        The normalised text.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # (regex matching the ASCII mark, full-width replacement). The regex side
    # must be escaped where the mark is a regex metacharacter (``?``); a raw
    # string avoids the invalid "\?" string-literal escape.
    punkts = (
        (",", "，"),
        ("!", "！"),
        (":", "："),
        (";", "；"),
        (r"\?", "？"),
    )
    for ascii_pattern, full_width in punkts:
        # CJK character followed by the mark.
        response = re.sub(
            r"([\u4e00-\u9fff])%s" % ascii_pattern, r"\1%s" % full_width, response
        )
        # Mark followed by a CJK character.
        response = re.sub(
            r"%s([\u4e00-\u9fff])" % ascii_pattern, r"%s\1" % full_width, response
        )
    return response
def _chatglm_usage(prompt_tokens, total_tokens):
    """Build the token-usage dict attached to every streamed chunk."""
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": total_tokens - prompt_tokens,
        "total_tokens": total_tokens,
    }


@torch.inference_mode()
def generate_stream_chatglm(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream text generated by a ChatGLM model.

    Tokenises ``params["prompt"]``, drives ``model.stream_generate`` and, for
    every step it yields, decodes the tokens so far, post-processes them with
    ``process_response`` and yields a chunk::

        {"text": str,
         "usage": {"prompt_tokens": int, "completion_tokens": int,
                   "total_tokens": int},
         "finish_reason": None}

    Once generation ends, one final chunk repeating the last text/usage with
    ``finish_reason="stop"`` is yielded.

    Args:
        model: ChatGLM model exposing ``device`` and ``stream_generate``.
        tokenizer: Matching tokenizer (called on a list of prompts, used for
            ``decode``).
        params: Request parameters. Keys read: ``prompt`` (required),
            ``temperature`` (default 1.0), ``repetition_penalty`` (1.0),
            ``top_p`` (1.0), ``max_new_tokens`` (256), ``echo`` (True; when
            truthy the prompt tokens are included in the decoded text).
        device: Unused; inputs are moved to ``model.device`` instead.
        context_len: Unused; kept for signature compatibility.
        stream_interval: Unused; every generation step is streamed.
        judge_sent_end: Unused; kept for signature compatibility.

    Yields:
        Chunk dicts as described above.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)

    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    # A (near-)zero temperature means greedy decoding; `temperature` is only
    # passed along when sampling is enabled.
    do_sample = temperature > 1e-5
    gen_kwargs = {
        "max_length": max_new_tokens + input_echo_len,
        "do_sample": do_sample,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [invalid_score_processor],
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    # Initialised so the final chunk is well-formed (empty text, zero
    # completion tokens) even if `stream_generate` yields nothing.
    response = ""
    total_len = input_echo_len
    for total_ids in model.stream_generate(**inputs, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        output_ids = total_ids if echo else total_ids[input_echo_len:]
        response = process_response(tokenizer.decode(output_ids))
        yield {
            "text": response,
            "usage": _chatglm_usage(input_echo_len, total_len),
            "finish_reason": None,
        }

    # TODO: ChatGLM stop when it reach max length
    # Only the last stream result carries a finish_reason. It is always
    # reported as "stop", even if generation ended because `max_length`
    # was reached.
    yield {
        "text": response,
        "usage": _chatglm_usage(input_echo_len, total_len),
        "finish_reason": "stop",
    }