""" | |
Inference code for ChatGLM. | |
Adapted from https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py. | |
""" | |
import re | |
import torch | |
from transformers.generation.logits_process import LogitsProcessor | |


class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Guard against NaN/Inf logits that ChatGLM occasionally produces."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            # Reset all scores and force a single fixed token (id 5) so that
            # sampling never operates on an invalid distribution.
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


invalid_score_processor = InvalidScoreLogitsProcessor()


def process_response(response):
    """Post-process a decoded ChatGLM response.

    Replaces the "[[训练时间]]" (training time) placeholder and converts ASCII
    punctuation adjacent to CJK characters into its full-width form.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for item in punkts:
        response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
        response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
    return response


@torch.inference_mode()
def generate_stream_chatglm(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
):
    # Sampling parameters supplied by the caller.
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)

    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    gen_kwargs = {
        "max_length": max_new_tokens + input_echo_len,
        "do_sample": temperature > 1e-5,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [invalid_score_processor],
    }
    # A temperature of (effectively) zero means greedy decoding, so only
    # forward the temperature argument when sampling is enabled.
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    response = ""  # initialized in case the model yields nothing
    for total_ids in model.stream_generate(**inputs, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids
        else:
            output_ids = total_ids[input_echo_len:]
        # Decode the full sequence generated so far and yield the partial response.
        response = tokenizer.decode(output_ids)
        response = process_response(response)

        yield {
            "text": response,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": total_len - input_echo_len,
                "total_tokens": total_len,
            },
            "finish_reason": None,
        }

    # TODO: ChatGLM also stops when it reaches the max length.
    # Only the last streamed result carries a finish_reason; we report "stop" here.
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret
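

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of how generate_stream_chatglm might be driven, assuming the
# THUDM/chatglm-6b checkpoint is reachable, a CUDA device is available, and the
# prompt/params values below are purely illustrative. The loaded model must
# expose ChatGLM's stream_generate method (provided by its remote code).
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_path = "THUDM/chatglm-6b"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
    model.eval()

    params = {
        "prompt": "你好",
        "temperature": 0.7,
        "top_p": 0.9,
        "max_new_tokens": 128,
        "echo": False,
    }

    # Each yielded chunk carries the full response decoded so far, so keeping
    # only the last chunk gives the complete answer.
    text = ""
    usage = {}
    for chunk in generate_stream_chatglm(model, tokenizer, params, device="cuda"):
        text = chunk["text"]
        usage = chunk["usage"]
    print(text)
    print(usage)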