import time
from abc import ABC, abstractmethod
from typing import List, Tuple

import torch
from transformers import AutoModel, AutoTokenizer
from transformers import LogitsProcessor, LogitsProcessorList

from .singleton import Singleton


def parse_codeblock(text):
    """Convert Markdown code fences in `text` to HTML for the web UI."""
    lines = text.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                # Opening fence: keep the language tag as the CSS class.
                lines[i] = f'<pre><code class="{lines[i][3:]}">'
            else:
                # Bare closing fence.
                lines[i] = '</code></pre>'
        else:
            if i > 0:
                # Escape raw HTML and mark the line break.
                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
    return "".join(lines)
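
# Example: parse_codeblock("hi\n```python\nprint(1)\n```") returns
# 'hi<pre><code class="python"><br/>print(1)</code></pre>'.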


class BasePredictor(ABC):

    @abstractmethod
    def __init__(self, model_name):
        self.model = None
        self.tokenizer = None

    @abstractmethod
    def stream_chat_continue(self, *args, **kwargs):
        raise NotImplementedError

    def predict_continue(self, query, latest_message, max_length, top_p,
                         temperature, allow_generate, history, *args,
                         **kwargs):
        if history is None:
            history = []
        allow_generate[0] = True
        history.append((query, latest_message))
        for response in self.stream_chat_continue(
                self.model,
                self.tokenizer,
                query=query,
                history=history,
                max_length=max_length,
                top_p=top_p,
                temperature=temperature):
            history[-1] = (history[-1][0], response)
            yield history, '', ''
            if not allow_generate[0]:
                break


class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __init__(self, start_pos=20005):
        self.start_pos = start_pos

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        # If the model emits NaN/Inf logits, reset the scores and force a
        # fixed token so generation does not crash.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., self.start_pos] = 5e4
        return scores


class ChatGLM(BasePredictor):

    def __init__(self, model_name="THUDM/chatglm-6b-int4"):
        print(f'Loading model {model_name}')
        start = time.perf_counter()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, resume_download=True)
        model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True,
            resume_download=True).half().to(self.device)
        model = model.eval()
        self.model = model
        self.model_name = model_name
        end = time.perf_counter()
        print(f'Successfully loaded model {model_name}, '
              f'time cost: {end - start:.2f}s')

    @torch.no_grad()
    def generator_image_text(self, text):
        # "描述画面" means "describe the scene": ask the model to caption
        # the given text description of an image.
        response, history = self.model.chat(
            self.tokenizer, "描述画面:{}".format(text), history=[])
        return response

    @torch.no_grad()
    def stream_chat_continue(self,
                             model,
                             tokenizer,
                             query: str,
                             history: List[Tuple[str, str]] = None,
                             max_length: int = 2048,
                             do_sample=True,
                             top_p=0.7,
                             temperature=0.95,
                             logits_processor=None,
                             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        # The partial answer, if any, that generation should continue from.
        if len(history) > 0:
            answer = history[-1][1]
        else:
            answer = ''
        logits_processor.append(
            InvalidScoreLogitsProcessor(
                start_pos=20005 if 'slim' not in self.model_name else 5))
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs
        }
        if not history:
            prompt = query
        else:
            # ChatGLM's native prompt format; 问/答 mean "Q"/"A".
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                if i != len(history) - 1:
                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(
                        i, old_query, response)
                else:
                    prompt += "[Round {}]\n问:{}\n答:".format(i, old_query)
        batch_input = tokenizer([prompt], return_tensors="pt", padding=True)
        batch_input = batch_input.to(model.device)

        batch_answer = tokenizer(answer, return_tensors="pt")
        batch_answer = batch_answer.to(model.device)

        input_length = len(batch_input['input_ids'][0])
        # Drop the trailing special tokens the tokenizer appends to `answer`
        # so generation continues the partial answer instead of closing it,
        # and keep the tensor on the model's device rather than forcing CUDA.
        final_input_ids = torch.cat(
            [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]],
            dim=-1).to(model.device)

        attention_mask = model.get_masks(
            final_input_ids, device=final_input_ids.device)

        batch_input['input_ids'] = final_input_ids
        batch_input['attention_mask'] = attention_mask

        input_ids = final_input_ids
        MASK = self.model.config.bos_token_id - 4
        gMASK = self.model.config.bos_token_id - 3
        mask_token = MASK if MASK in input_ids else gMASK
        mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
        # ChatGLM uses two-dimensional position encoding built relative to
        # the [MASK]/[gMASK] slot located above, so the mask positions are
        # passed through to get_position_ids.
        batch_input['position_ids'] = self.model.get_position_ids(
            input_ids, mask_positions, device=input_ids.device)

        for outputs in model.stream_generate(**batch_input, **gen_kwargs):
            outputs = outputs.tolist()[0][input_length:]
            response = tokenizer.decode(outputs)
            response = model.process_response(response)
            yield parse_codeblock(response)


@Singleton
class Models(object):

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails, so each model is
        # loaded lazily on first access and cached on the instance.
        if item == 'chatglm':
            self.chatglm = ChatGLM("THUDM/chatglm-6b-int4")
            return self.chatglm
        raise AttributeError(item)


models = Models.instance()


def chat2text(text: str) -> str:
    return models.chatglm.generator_image_text(text)
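

# Minimal usage sketch, assuming a machine that can load
# THUDM/chatglm-6b-int4. Run it as a module (e.g. `python -m app.model`,
# substituting the real package name) because of the relative `.singleton`
# import. `allow_generate` is the one-element list predict_continue uses as
# a mutable stop flag: set allow_generate[0] = False to stop streaming.
if __name__ == "__main__":
    allow_generate = [True]
    chat_history = []
    for chat_history, _, _ in models.chatglm.predict_continue(
            query="Hello", latest_message="", max_length=2048, top_p=0.7,
            temperature=0.95, allow_generate=allow_generate,
            history=chat_history):
        pass  # chat_history[-1] is updated with each streamed chunk
    print(chat_history[-1][1])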