# coding: utf-8
import warnings
warnings.filterwarnings("ignore")

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Module-level singleton; created lazily by init_client() on first use.
client = None

def get_prompt(message):
    """Assemble the Haruhi-Zero prompt string from an OpenAI-style message list."""
    # Concatenate every system message into the persona block.
    persona = ""
    for msg in message:
        if msg["role"] == "system":
            persona = persona + msg["content"]
    prompt = "<<SYS>>" + persona + "<</SYS>>"

    # Rewrite the remaining messages into strictly alternating user/assistant turns.
    from ChatHaruhi.utils import normalize2uaua
    message_ua = normalize2uaua(message[1:], if_replace_system=True)

    # Append each completed user/assistant pair, then leave the final user turn
    # open so the model generates the next assistant reply.
    for i in range(0, len(message_ua) - 1, 2):
        prompt = prompt + "[INST]" + message_ua[i]["content"] + "[/INST]" + message_ua[i + 1]["content"] + "<|im_end|>"
    prompt = prompt + "[INST]" + message_ua[-1]["content"] + "[/INST]"
    # print(prompt)  # uncomment to inspect the assembled prompt
    return prompt

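# Illustrative sketch (not part of the original module): assuming normalize2uaua
# leaves an already alternating user/assistant history unchanged, a two-turn
# conversation
#
#   messages = [
#       {"role": "system", "content": "You are Haruhi Suzumiya."},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Who are you?"},
#   ]
#
# is flattened by get_prompt(messages) into:
#
#   <<SYS>>You are Haruhi Suzumiya.<</SYS>>[INST]Hi[/INST]Hello!<|im_end|>[INST]Who are you?[/INST]
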
class qwen_model:
    def __init__(self, model_name):
        self.DEVICE = torch.device("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(
            "silk-road/" + model_name,
            low_cpu_mem_usage=True,
            use_fast=False,
            padding_side="left",
            trust_remote_code=True
        )

        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        # 151645 is Qwen's <|im_end|> token; treat it as end-of-sequence.
        self.tokenizer.eos_token_id = 151645

        self.model = AutoModelForCausalLM.from_pretrained(
            "silk-road/" + model_name,
            load_in_8bit=False,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map='auto',
            trust_remote_code=True,
        ).eval()

    def get_response(self, message):
        with torch.inference_mode():
            prompt = get_prompt(message)
            batch = self.tokenizer(prompt,
                                   return_tensors="pt",
                                   padding=True,
                                   add_special_tokens=False)
            batch = {k: v.to(self.DEVICE) for k, v in batch.items()}
            # Note: with do_sample=False decoding is greedy, so temperature/top_p/top_k
            # have no effect; they are kept to match the original configuration.
            generated = self.model.generate(input_ids=batch["input_ids"],
                                            max_new_tokens=1024,
                                            temperature=0.2,
                                            top_p=0.9,
                                            top_k=40,
                                            do_sample=False,
                                            num_beams=1,
                                            repetition_penalty=1.3,
                                            eos_token_id=self.tokenizer.eos_token_id,
                                            pad_token_id=self.tokenizer.pad_token_id)
            # Decode only the newly generated tokens (everything after the prompt)
            # and strip the trailing <|im_end|> marker.
            response = self.tokenizer.decode(generated[0][batch["input_ids"].shape[1]:]).strip().replace("<|im_end|>", "")
        return response


def init_client(model_name):
    # Store the model wrapper in the module-level `client` so it is loaded only once.
    global client
    client = qwen_model(model_name=model_name)

def get_response(message, model_name="Haruhi-Zero-1_8B-0_4"):
    # Lazily load the model on first use; subsequent calls reuse the cached client.
    if client is None:
        init_client(model_name)

    response = client.get_response(message)
    return response