File size: 3,739 Bytes
0bae6cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations

import logging
import os
import platform

import gc
import torch
import colorama

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel


class ChatGLM_Client(BaseLLMModel):
    """Chat client that runs a ChatGLM checkpoint locally via ``transformers``.

    The loaded tokenizer and model are stored in the module-level globals
    ``CHATGLM_TOKENIZER`` and ``CHATGLM_MODEL`` rather than on the instance.
    """

    def __init__(self, model_name, user_name="") -> None:
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Checkpoint name. If ``models/<model_name>`` exists it
                is loaded from there; otherwise ``THUDM/<model_name>`` is
                handed to ``from_pretrained``.
            user_name: Forwarded to the base class as ``user``.
        """
        super().__init__(model_name=model_name, user=user_name)
        # Kept function-local so importing this module does not import
        # transformers.
        from transformers import AutoModel, AutoTokenizer

        global CHATGLM_TOKENIZER, CHATGLM_MODEL
        # Release whatever was loaded before. deinitialize() as defined in
        # this class resets both globals to None, so the branch below runs on
        # every construction.
        self.deinitialize()
        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
            system_name = platform.system()

            # Prefer a local copy under ./models over the hub identifier.
            model_path = None
            if os.path.exists("models") and model_name in os.listdir("models"):
                model_path = f"models/{model_name}"
            model_source = (
                model_path if model_path is not None else f"THUDM/{model_name}"
            )

            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, trust_remote_code=True
            )
            quantified = "int4" in model_name
            model = AutoModel.from_pretrained(
                model_source, trust_remote_code=True
            )

            # torch.backends.mps is absent on old torch builds, and present
            # but unavailable on Macs without Metal support; in both cases
            # .to("mps") would raise, so fall through to the CPU branch.
            mps_backend = getattr(torch.backends, "mps", None)
            mps_available = (
                mps_backend is not None and mps_backend.is_available()
            )

            if torch.cuda.is_available():
                # run on CUDA
                logging.info("CUDA is available, using CUDA")
                model = model.half().cuda()
            # NOTE(review): the original comment here said MPS acceleration
            # still has problems; it is only attempted for a locally
            # downloaded, non-quantized model.
            elif (
                system_name == "Darwin"
                and model_path is not None
                and not quantified
                and mps_available
            ):
                logging.info("Running on macOS, using MPS")
                # running on macOS and model already downloaded
                model = model.half().to("mps")
            else:
                logging.info("GPU is not available, using CPU")
                model = model.float()
            model = model.eval()
            CHATGLM_MODEL = model

    def _get_glm3_style_input(self):
        """Return ``(history, query)`` using the stored message dicts as-is.

        ``query`` is the ``content`` of the last entry of ``self.history``;
        ``history`` is a new list holding all earlier entries.

        Raises:
            IndexError: If ``self.history`` is empty.
        """
        # Pop from a shallow copy: popping from self.history itself would
        # remove the pending message from the stored conversation, which the
        # glm2 path below does not do either.
        history = list(self.history)
        query = history.pop()["content"]
        return history, query

    def _get_glm2_style_input(self):
        """Return ``(history, query)`` with history as ``[first, second]`` pairs.

        The ``content`` fields of ``self.history`` are extracted, the last one
        becomes ``query`` and the remaining ones are grouped two by two.

        Raises:
            IndexError: If ``self.history`` is empty.
            AssertionError: If an odd number of entries precedes the query.
        """
        history = [x["content"] for x in self.history]
        query = history.pop()
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        assert (
            len(history) % 2 == 0
        ), f"History should be even length. current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def _get_glm_style_input(self):
        """Pick the history format based on whether the name contains ``glm2``."""
        if "glm2" in self.model_name:
            return self._get_glm2_style_input()
        return self._get_glm3_style_input()

    def get_answer_at_once(self):
        """Generate a full reply in one call.

        Returns:
            A ``(response, length)`` tuple where ``length`` is
            ``len(response)``.
        """
        history, query = self._get_glm_style_input()
        response, _ = CHATGLM_MODEL.chat(
            CHATGLM_TOKENIZER, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
        """Yield each response value produced by the model's ``stream_chat``.

        ``self.token_upper_limit``, ``self.top_p`` and ``self.temperature``
        are forwarded as ``max_length``, ``top_p`` and ``temperature``.
        """
        history, query = self._get_glm_style_input()
        # The history returned alongside each response is not used here.
        for response, _ in CHATGLM_MODEL.stream_chat(
            CHATGLM_TOKENIZER,
            query,
            history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        ):
            yield response

    def deinitialize(self):
        """Drop the global model and tokenizer and free cached GPU memory."""
        # Free GPU memory: clear the references, collect garbage, then ask
        # torch to empty its CUDA cache.
        global CHATGLM_MODEL, CHATGLM_TOKENIZER
        CHATGLM_MODEL = None
        CHATGLM_TOKENIZER = None
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("ChatGLM model deinitialized")