#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author:quincy qiang
@license: Apache Licence
@file: generate.py
@time: 2023/04/17
@contact: yanqiangmiffy@gamil.com
@software: PyCharm
@description: coding..
"""

from typing import List, Optional

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from transformers import AutoModel, AutoTokenizer


class ChatGLMService(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded ChatGLM chat model.

    Call :meth:`load_model` once before issuing prompts; until then
    ``tokenizer`` and ``model`` are ``None`` and :meth:`_call` raises.

    Attributes:
        max_token: Forwarded to ``model.chat`` as ``max_length``.
        temperature: Sampling temperature forwarded to ``model.chat``.
        top_p: Nucleus-sampling threshold forwarded to ``model.chat``.
        history: Conversation history forwarded to ``model.chat`` and
            extended after every call.
        tokenizer: Tokenizer set by :meth:`load_model`.
        model: Model set by :meth:`load_model`.
    """

    # Every field carries an annotation so all of them are declared the same
    # way (un-annotated class attributes are rejected by newer pydantic).
    max_token: int = 10000
    temperature: float = 0.1
    top_p: float = 0.9
    # NOTE(review): presumably the LLM base class copies this default per
    # instance; _call also rebinds rather than mutates it in place.
    history: List[List[Optional[str]]] = []
    tokenizer: object = None
    model: object = None

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM implementation."""
        return "ChatGLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Generate a reply to ``prompt`` and record it in ``history``.

        Args:
            prompt: Text passed to the model's ``chat`` method.
            stop: Optional stop sequences; when given, the reply is passed
                through ``enforce_stop_tokens`` before being stored.

        Returns:
            The (possibly truncated) model reply.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "ChatGLMService model is not loaded; call load_model() first."
            )

        response, _ = self.model.chat(
            self.tokenizer,
            prompt,
            history=self.history,
            max_length=self.max_token,
            top_p=self.top_p,
            temperature=self.temperature,
        )
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        # NOTE(review): the query slot is stored as None, so the prompt itself
        # is not kept in the history -- looks deliberate (prompts may be
        # long), but confirm how model.chat renders a None query.
        self.history = self.history + [[None, response]]
        return response

    def load_model(self, model_name_or_path: str = "THUDM/chatglm-6b"):
        """Load the tokenizer and model and store them on this instance.

        The model is converted to half precision, moved to the CUDA device
        and switched to evaluation mode.

        Args:
            model_name_or_path: Model identifier or local path handed to
                ``from_pretrained``.
        """
        # SECURITY: trust_remote_code=True lets the model repository supply
        # Python code that is executed locally; only load trusted sources.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = (
            AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
            .half()
            .cuda()
            .eval()
        )


# Usage example (kept commented out, as in the original):
# if __name__ == '__main__':
#     config=LangChainCFG()
#     chatLLM = ChatGLMService()
#     chatLLM.load_model(model_name_or_path=config.llm_model_name)