Spaces:
Sleeping
Sleeping
import os.path as osp
from typing import Dict, List, Optional, Tuple, Union

from transformers import AutoTokenizer, AutoModel
from transformers.generation import GenerationConfig
class ChatGLM2Wrapper:
    """Thin chat wrapper around a ChatGLM2 checkpoint loaded via transformers.

    The wrapper keeps an optional system prompt, checks that a conversation
    fits into the model's context window (dropping the oldest turns when it
    does not) and forwards the final message plus its history to the model's
    own ``chat`` method.
    """

    def __init__(self,
                 model_path: str = 'THUDM/chatglm2-6b-int4',
                 system_prompt: Optional[str] = None,
                 temperature: float = 0,
                 **model_kwargs):
        """Load tokenizer and model and record the generation settings.

        Args:
            model_path: An existing local path, or a two-part identifier of
                the form ``owner/name``.
            system_prompt: Optional text placed at the start of every
                conversation history.
            temperature: Value forwarded as ``temperature`` to the model's
                ``chat`` call.
            **model_kwargs: Accepted for call-site compatibility only; each
                entry is reported on stdout and otherwise ignored.

        Raises:
            ValueError: If ``model_path`` is neither an existing path nor of
                the form ``owner/name``.
        """
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.model_path = model_path

        # Explicit raise instead of `assert`, which is stripped under -O.
        if not (osp.exists(model_path) or len(model_path.split('/')) == 2):
            raise ValueError(
                f'model_path must be an existing path or look like '
                f'"owner/name", got {model_path!r}')

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_path, trust_remote_code=True).half().cuda()

        # Best effort: a checkpoint without a loadable generation config
        # keeps whatever configuration the model already carries. Only
        # ordinary exceptions are swallowed, so Ctrl-C still interrupts.
        try:
            self.generation_config = GenerationConfig.from_pretrained(
                model_path, trust_remote_code=True)
            self.model.generation_config = self.generation_config
        except Exception as err:
            print(f'Could not load a generation config from {model_path} '
                  f'({err}); keeping the model default. ')

        self.model = self.model.eval()
        self.context_length = self.model.config.seq_length
        # Token headroom kept free when checking whether a prompt fits.
        self.answer_buffer = 192

        for k, v in model_kwargs.items():
            print(f'Following args are passed but not used to initialize the model, {k}: {v}. ')

    def length_ok(self, inputs) -> bool:
        """Return True when the prompt leaves room for the answer buffer.

        The token counts of the system prompt (if any) and of every string in
        ``inputs`` are summed; the total plus ``answer_buffer`` must stay
        strictly below ``context_length``.
        """
        tot = 0
        if self.system_prompt is not None:
            tot = len(self.tokenizer.encode(self.system_prompt))
        tot += sum(len(self.tokenizer.encode(s)) for s in inputs)
        return tot + self.answer_buffer < self.context_length

    def chat(self,
             full_inputs: Union[str, List[str]],
             offset: int = 0) -> Tuple[str, int]:
        """Send a conversation to the model and return its reply.

        Args:
            full_inputs: A single message, or the list of conversation turns
                ending with the message to answer.
            offset: Number of leading turns to skip before the length check.

        Returns:
            ``(response, offset)`` where ``offset`` is the number of leading
            turns that were actually skipped so that the prompt fits.

        Raises:
            ValueError: If no turn is left after trimming, or if an even
                number of turns remains while no system prompt is set.
        """
        # A bare string is one message, not a sequence of characters.
        if isinstance(full_inputs, str):
            full_inputs = [full_inputs]

        # Drop the oldest turns one at a time until the prompt fits. The loop
        # is bounded by the number of turns, so it cannot run away when even
        # the system prompt alone exceeds the budget.
        while offset < len(full_inputs) and not self.length_ok(full_inputs[offset:]):
            offset += 1
        inputs = full_inputs[offset:]
        if not inputs:
            raise ValueError(
                'No input left to send: the conversation is empty or does '
                'not fit into the context window.')

        msg = inputs[-1]
        if len(inputs) % 2 == 1:
            # Odd count: turns pair up from the start, last one is the query.
            history_base = []
            if self.system_prompt is not None:
                history_base = [(self.system_prompt, '')]
            turns = inputs[:-1]
        else:
            # Even count: the first turn is paired with the system prompt.
            if self.system_prompt is None:
                raise ValueError(
                    'An even number of inputs requires a system prompt.')
            history_base = [(self.system_prompt, inputs[0])]
            turns = inputs[1:-1]
        history = list(zip(turns[0::2], turns[1::2]))

        response, _ = self.model.chat(
            self.tokenizer,
            msg,
            history=history_base + history,
            do_sample=False,
            temperature=self.temperature)
        return response, offset