|
from __future__ import annotations |
|
|
|
import logging |
|
import os |
|
import platform |
|
|
|
import gc |
|
import torch |
|
import colorama |
|
|
|
from ..index_func import * |
|
from ..presets import * |
|
from ..utils import * |
|
from .base_model import BaseLLMModel |
|
|
|
|
|
class ChatGLM_Client(BaseLLMModel):
    """Chat client backed by a THUDM ChatGLM checkpoint loaded in-process.

    The tokenizer and model are kept in the module-level globals
    ``CHATGLM_TOKENIZER`` and ``CHATGLM_MODEL`` rather than on the instance;
    constructing a client replaces whatever ChatGLM model was loaded before.
    """

    def __init__(self, model_name, user_name="") -> None:
        """Load the tokenizer and model for *model_name*.

        Weights come from the local directory ``models/<model_name>`` when it
        exists, otherwise from ``THUDM/<model_name>``. Device placement is:
        CUDA in fp16 if available; otherwise MPS in fp16 on macOS for a local,
        non-"int4" checkpoint; otherwise CPU in fp32.

        Args:
            model_name: checkpoint name (a local directory name under
                ``models/`` or a ``THUDM/`` repository name).
            user_name: forwarded to ``BaseLLMModel`` as ``user``.
        """
        super().__init__(model_name=model_name, user=user_name)
        from transformers import AutoModel, AutoTokenizer

        global CHATGLM_TOKENIZER, CHATGLM_MODEL

        # Always drop the previously loaded model first: the globals are shared
        # and the new client may request a different ChatGLM generation. After
        # this call both globals are None, so there is nothing to reuse.
        self.deinitialize()

        local_dir = f"models/{model_name}"
        model_path = local_dir if os.path.isdir(local_dir) else None
        model_source = model_path if model_path is not None else f"THUDM/{model_name}"
        quantized = "int4" in model_name

        tokenizer = AutoTokenizer.from_pretrained(model_source, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_source, trust_remote_code=True)

        if torch.cuda.is_available():
            logging.info("CUDA is available, using CUDA")
            model = model.half().cuda()
        elif platform.system() == "Darwin" and model_path is not None and not quantized:
            # MPS is only attempted for local, non-int4 checkpoints; presumably
            # the quantized kernels do not run on MPS — TODO confirm.
            logging.info("Running on macOS, using MPS")
            model = model.half().to("mps")
        else:
            logging.info("GPU is not available, using CPU")
            model = model.float()

        # Publish both globals only after loading and device placement succeed,
        # so a failure never leaves a tokenizer set without a matching model.
        CHATGLM_TOKENIZER = tokenizer
        CHATGLM_MODEL = model.eval()

    def _get_glm3_style_input(self):
        """Return ``(history, query)`` for ChatGLM3's role/content-dict chat API.

        ``query`` is the content of the last message in ``self.history`` and
        ``history`` is a NEW list holding every earlier message dict.
        ``self.history`` is not modified: the model's chat helpers may append
        to the list they receive, and a failed generation must not drop the
        user's latest message from the conversation.

        Raises:
            IndexError: if ``self.history`` is empty.
        """
        history = self.history[:-1]
        query = self.history[-1]["content"]
        return history, query

    def _get_glm2_style_input(self):
        """Return ``(history, query)`` for the pair-based ChatGLM/ChatGLM2 chat API.

        ``query`` is the content of the last message; the remaining message
        contents are grouped into consecutive ``[first, second]`` pairs
        (presumably user/assistant turns). ``self.history`` is not modified.

        Raises:
            IndexError: if ``self.history`` is empty.
            AssertionError: if the earlier messages cannot be paired up.
        """
        contents = [message["content"] for message in self.history]
        query = contents.pop()
        logging.debug(colorama.Fore.YELLOW + f"{contents}" + colorama.Fore.RESET)
        assert (
            len(contents) % 2 == 0
        ), f"History should be even length. current history is: {contents}"
        history = [[contents[i], contents[i + 1]] for i in range(0, len(contents), 2)]
        return history, query

    def _get_glm_style_input(self):
        """Build ``(history, query)`` in the format the loaded model expects.

        Only ChatGLM3 checkpoints take role/content dicts. Earlier generations
        (``chatglm-6b*`` as well as ``chatglm2-6b*``) take ``[query, response]``
        pairs, so everything that is not GLM3 uses the pair format.
        """
        if "glm3" in self.model_name.lower():
            return self._get_glm3_style_input()
        return self._get_glm2_style_input()

    def get_answer_at_once(self):
        """Generate the reply to the latest message in a single call.

        Uses the same generation settings as the streaming path
        (``token_upper_limit``, ``top_p``, ``temperature``).

        Returns:
            ``(response, len(response))`` — the reply text and its length in
            characters.
        """
        history, query = self._get_glm_style_input()
        response, _ = CHATGLM_MODEL.chat(
            CHATGLM_TOKENIZER,
            query,
            history=history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        )
        return response, len(response)

    def get_answer_stream_iter(self):
        """Yield the reply to the latest message as the model streams it.

        Each yielded value is the response string reported by the model's
        ``stream_chat`` at that step (presumably the reply so far).
        """
        history, query = self._get_glm_style_input()
        for response, _ in CHATGLM_MODEL.stream_chat(
            CHATGLM_TOKENIZER,
            query,
            history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        ):
            yield response

    def deinitialize(self):
        """Drop the shared ChatGLM model/tokenizer and release cached CUDA memory."""
        global CHATGLM_MODEL, CHATGLM_TOKENIZER
        CHATGLM_MODEL = None
        CHATGLM_TOKENIZER = None
        # Collect first so the dropped tensors are actually freed before the
        # CUDA caching allocator is asked to return its unused blocks.
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("ChatGLM model deinitialized")
|
|