"""ChatGLM2GPT: a BaseLLM wrapper around THUDM/chatglm2-6b, optionally fused
with the silk-road/Chat-Haruhi-Fusion_B LoRA adapter."""

import torch
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

from .BaseLLM import BaseLLM

# Module-level singletons so the 6B model and tokenizer are loaded only once.
tokenizer_GLM = None
model_GLM = None

def initialize_GLM2LORA():
    """Lazily load chatglm2-6b in fp16 and apply the Chat-Haruhi LoRA weights."""
    global model_GLM, tokenizer_GLM

    if model_GLM is None:
        # Base model: fp16 weights, spread across available devices.
        model_GLM = AutoModel.from_pretrained(
            "THUDM/chatglm2-6b",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Wrap the base model with the LoRA adapter.
        model_GLM = PeftModel.from_pretrained(
            model_GLM,
            "silk-road/Chat-Haruhi-Fusion_B",
        )
        model_GLM = model_GLM.eval()  # inference only; disables dropout

    if tokenizer_GLM is None:
        tokenizer_GLM = AutoTokenizer.from_pretrained(
            "THUDM/chatglm2-6b",
            use_fast=True,
            trust_remote_code=True,
        )

    return model_GLM, tokenizer_GLM

def GLM_tokenizer(text):
    if tokenizer_GLM is None:
        initialize_GLM2LORA()  # lazy-load on first use instead of crashing
    return len(tokenizer_GLM.encode(text))

class ChatGLM2GPT(BaseLLM):
    def __init__(self, model="haruhi-fusion"):
        super().__init__()
        if model == "glm2-6b":
            # Plain ChatGLM2 without the LoRA adapter.
            self.tokenizer = AutoTokenizer.from_pretrained(
                "THUDM/chatglm2-6b",
                use_fast=True,
                trust_remote_code=True,
            )
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b",
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
            )
        elif model == "haruhi-fusion":
            self.model, self.tokenizer = initialize_GLM2LORA()
        else:
            raise ValueError(f"Unknown GLM model: {model!r}")
        self.messages = ""

    def initialize_message(self):
        # ChatGLM2's chat() consumes a single prompt string, so the
        # conversation is kept as one flat string rather than a role list;
        # all three roles below therefore append identically.
        self.messages = ""

    def ai_message(self, payload):
        self.messages = self.messages + "\n " + payload

    def system_message(self, payload):
        self.messages = self.messages + "\n " + payload

    def user_message(self, payload):
        self.messages = self.messages + "\n " + payload

    def get_response(self):
        # Single-turn call: the accumulated prompt is sent as-is and no
        # chat history is carried between calls.
        with torch.no_grad():
            response, _history = self.model.chat(self.tokenizer, self.messages, history=[])
        return response
        
    def print_prompt(self):
        # Debug helper: show the raw prompt string handed to model.chat().
        print(self.messages)
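
# Minimal usage sketch (an assumption, not part of the original file): it
# presumes a GPU with enough memory for chatglm2-6b in fp16, and that this
# module is imported through its package so the relative BaseLLM import
# resolves.
#
#     chatbot = ChatGLM2GPT("haruhi-fusion")
#     chatbot.initialize_message()
#     chatbot.system_message("You are Haruhi Suzumiya.")
#     chatbot.user_message("Hello, who are you?")
#     print(chatbot.get_response())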