IMPChat / modules /models /ChatGLM.py
MILVLG's picture
Upload 107 files
0bae6cd verified
raw
history blame
3.74 kB
from __future__ import annotations
import logging
import os
import platform
import gc
import torch
import colorama
from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel
class ChatGLM_Client(BaseLLMModel):
    """Chat client backed by a ChatGLM model loaded through ``transformers``.

    The tokenizer and model are stored in the module-level globals
    ``CHATGLM_TOKENIZER`` and ``CHATGLM_MODEL`` rather than on the instance,
    so constructing a new client replaces whatever was loaded before.
    """

    def __init__(self, model_name, user_name="") -> None:
        """Load the tokenizer and model for ``model_name`` into the globals.

        Args:
            model_name: Checkpoint name. If an entry with this name exists in
                the local ``models`` directory it is loaded from
                ``models/<model_name>``; otherwise ``THUDM/<model_name>`` is
                handed to ``from_pretrained``.
            user_name: Forwarded to the base class as ``user``.
        """
        super().__init__(model_name=model_name, user=user_name)
        # Imported here rather than at module level, so this module does not
        # itself import transformers until a client is constructed.
        from transformers import AutoModel, AutoTokenizer

        global CHATGLM_TOKENIZER, CHATGLM_MODEL
        # Reset the globals (and release any previously loaded model) first.
        self.deinitialize()
        # NOTE(review): deinitialize() as defined in this class always sets
        # both globals to None, so this condition is true unless a subclass
        # overrides deinitialize().
        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
            system_name = platform.system()

            # Prefer a local copy under ./models when one is present.
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            else:
                model_source = f"THUDM/{model_name}"

            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, trust_remote_code=True
            )
            # Checkpoints with "int4" in their name are treated as quantized
            # and are never moved to MPS below.
            quantified = "int4" in model_name
            model = AutoModel.from_pretrained(
                model_source, trust_remote_code=True
            )
            if torch.cuda.is_available():
                # run on CUDA
                logging.info("CUDA is available, using CUDA")
                model = model.half().cuda()
            # NOTE(review): the original comment here said MPS acceleration
            # still has some problems and is not used for now, yet this
            # branch is active. Confirm which is intended.
            elif system_name == "Darwin" and model_path is not None and not quantified:
                logging.info("Running on macOS, using MPS")
                # running on macOS and model already downloaded
                model = model.half().to("mps")
            else:
                logging.info("GPU is not available, using CPU")
                model = model.float()
            model = model.eval()
            CHATGLM_MODEL = model

    def _get_glm3_style_input(self):
        """Split the conversation into ``(history, query)`` for GLM3-style chat.

        Returns:
            A tuple ``(history, query)`` where ``history`` is a new list with
            every message of ``self.history`` except the last, and ``query``
            is the ``"content"`` of that last message.
        """
        # Work on a shallow copy: popping from self.history directly would
        # remove the latest message from the stored conversation, and would
        # let the model mutate the stored list in place.
        history = list(self.history)
        query = history.pop()["content"]
        return history, query

    def _get_glm2_style_input(self):
        """Split the conversation into ``(history, query)`` for GLM2-style chat.

        Returns:
            A tuple ``(history, query)`` where ``query`` is the content of the
            last message and ``history`` is a list of two-element lists built
            from consecutive pairs of the preceding message contents.

        Raises:
            AssertionError: If the number of preceding messages is odd.
        """
        contents = [message["content"] for message in self.history]
        query = contents.pop()
        logging.debug(colorama.Fore.YELLOW +
                      f"{contents}" + colorama.Fore.RESET)
        assert (
            len(contents) % 2 == 0
        ), f"History should be even length. current history is: {contents}"
        # Group the flat content list into consecutive pairs.
        history = [[contents[i], contents[i + 1]]
                   for i in range(0, len(contents), 2)]
        return history, query

    def _get_glm_style_input(self):
        """Return ``(history, query)`` in the format selected by the model name.

        Model names containing ``"glm2"`` use the paired-list format; every
        other name uses the GLM3-style message list.
        """
        if "glm2" in self.model_name:
            return self._get_glm2_style_input()
        return self._get_glm3_style_input()

    def get_answer_at_once(self):
        """Generate the full reply in one call.

        Returns:
            A tuple ``(response, length)`` where ``length`` is
            ``len(response)``, i.e. the character count of the reply.
        """
        history, query = self._get_glm_style_input()
        # Pass the same generation settings as get_answer_stream_iter so both
        # modes honour the configured limits and sampling parameters.
        response, _ = CHATGLM_MODEL.chat(
            CHATGLM_TOKENIZER,
            query,
            history=history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        )
        return response, len(response)

    def get_answer_stream_iter(self):
        """Yield the reply as produced by the model's ``stream_chat``.

        Yields:
            The ``response`` element of each item that ``stream_chat`` yields.
        """
        history, query = self._get_glm_style_input()
        # The updated history yielded by stream_chat is not used here.
        for response, _updated_history in CHATGLM_MODEL.stream_chat(
            CHATGLM_TOKENIZER,
            query,
            history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        ):
            yield response

    def deinitialize(self):
        """Drop the global model and tokenizer and release cached GPU memory."""
        # Release GPU memory
        global CHATGLM_MODEL, CHATGLM_TOKENIZER
        CHATGLM_MODEL = None
        CHATGLM_TOKENIZER = None
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("ChatGLM model deinitialized")