WayneWuDH's picture
Upload 253 files
e6828c9
from abc import ABC
from langchain.llms.base import LLM
from typing import Optional, List
from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
AnswerResult)
class ChatGLM(BaseAnswer, LLM, ABC):
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 10
def __init__(self, checkPoint: LoaderCheckPoint = None):
super().__init__()
self.checkPoint = checkPoint
@property
def _llm_type(self) -> str:
return "ChatGLM"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def _history_len(self) -> int:
return self.history_len
def set_history_len(self, history_len: int = 10) -> None:
self.history_len = history_len
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
print(f"__call:{prompt}")
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
prompt,
history=[],
max_length=self.max_token,
temperature=self.temperature
)
print(f"response:{response}")
print(f"+++++++++++++++++++++++++++++++++++")
return response
def generatorAnswer(self, prompt: str,
history: List[List[str]] = [],
streaming: bool = False):
if streaming:
history += [[]]
for inum, (stream_resp, _) in enumerate(self.checkPoint.model.stream_chat(
self.checkPoint.tokenizer,
prompt,
history=history[-self.history_len:-1] if self.history_len > 1 else [],
max_length=self.max_token,
temperature=self.temperature
)):
# self.checkPoint.clear_torch_cache()
history[-1] = [prompt, stream_resp]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": stream_resp}
yield answer_result
else:
response, _ = self.checkPoint.model.chat(
self.checkPoint.tokenizer,
prompt,
history=history[-self.history_len:] if self.history_len > 0 else [],
max_length=self.max_token,
temperature=self.temperature
)
self.checkPoint.clear_torch_cache()
history += [[prompt, response]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": response}
yield answer_result