import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel  # only needed if the LoRA adapter block below is enabled

from .BaseLLM import BaseLLM

# Module-level cache so repeated instantiations reuse the same weights/tokenizer.
tokenizer_qwen = None
model_qwen = None



def initialize_Qwen2LORA(model):
    """Lazily load a Qwen checkpoint and its tokenizer, caching both globally."""
    global model_qwen, tokenizer_qwen

    if model_qwen is None:
        model_qwen = AutoModelForCausalLM.from_pretrained(
            model,
            device_map="auto",
            trust_remote_code=True
        )
        model_qwen = model_qwen.eval()
        # Uncomment (together with the PeftModel import) to stack the LoRA adapter:
        # model_qwen = PeftModel.from_pretrained(
        #     model_qwen,
        #     "silk-road/Chat-Haruhi-Fusion_B"
        # )

    if tokenizer_qwen is None:
        tokenizer_qwen = AutoTokenizer.from_pretrained(
            model,
            trust_remote_code=True
        )

    return model_qwen, tokenizer_qwen
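
# Illustrative note (not in the original): the loader is idempotent, so repeated
# calls reuse the cached objects, e.g.
#
#     m1, t1 = initialize_Qwen2LORA("silk-road/Chat-Haruhi-Fusion_B")
#     m2, t2 = initialize_Qwen2LORA("silk-road/Chat-Haruhi-Fusion_B")
#     assert m1 is m2 and t1 is t2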

def Qwen_tokenizer(text):
    """Return the token count of `text`; requires the tokenizer to be initialized first."""
    return len(tokenizer_qwen.encode(text))

class Qwen118k2GPT(BaseLLM):
    def __init__(self, model):
        super().__init__()
        global model_qwen, tokenizer_qwen
        if model == "Qwen/Qwen-1_8B-Chat":
            # Load the stock Qwen-1.8B chat model directly from the hub.
            tokenizer_qwen = AutoTokenizer.from_pretrained(
                "Qwen/Qwen-1_8B-Chat",
                trust_remote_code=True
            )
            model_qwen = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen-1_8B-Chat",
                device_map="auto",
                trust_remote_code=True
            ).eval()
            self.model = model_qwen
            self.tokenizer = tokenizer_qwen
        elif "silk-road/" in model:
            # silk-road checkpoints go through the shared lazy loader above.
            self.model, self.tokenizer = initialize_Qwen2LORA(model)
        else:
            raise ValueError(f"Unknown Qwen model: {model}")
        self.messages = ""

    def initialize_message(self):
        self.messages = ""

    def ai_message(self, payload):
        # Append the new turn to the running transcript, labelled by role.
        self.messages = self.messages + "\nAI: " + payload

    def system_message(self, payload):
        self.messages = self.messages + "\nSYSTEM PROMPT: " + payload

    def user_message(self, payload):
        self.messages = self.messages + "\nUser: " + payload

    def get_response(self):
        # model.chat() comes from Qwen's trust_remote_code implementation; the
        # transcript is kept as plain text in self.messages, so no chat history
        # list is passed.
        with torch.no_grad():
            response, history = self.model.chat(self.tokenizer, self.messages, history=[])
        return response

    def print_prompt(self):
        print(type(self.messages))
        print(self.messages)
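

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes this file sits in a package next to BaseLLM.py and that the
# "Qwen/Qwen-1_8B-Chat" weights are reachable on the Hugging Face hub:
#
#     llm = Qwen118k2GPT("Qwen/Qwen-1_8B-Chat")
#     llm.initialize_message()
#     llm.system_message("You are a helpful role-play assistant.")
#     llm.user_message("Hello!")
#     print(llm.get_response())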