import os
from string import Template
from typing import List, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM


# Module-level handles, populated lazily by init_client().
aclient = None

client = None
tokenizer = None

# Mirror endpoint used to speed up HuggingFace downloads.
END_POINT = "https://hf-mirror.com"


def init_client(model_name: str, verbose: bool) -> None:
    """
        Initialize the model, loading it for inference on whatever device is available.

        Params:
            model_name (`str`)
                Name of the model repo on HuggingFace, e.g. "THUDM/chatglm3-6b"
            verbose (`bool`)
                Whether to print progress information.
    """

    # Bind to the module-level client and tokenizer.
    global client
    global tokenizer

    # Choose a device: prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    if verbose:
        print("Using device: ", device)

    # TODO: switch to loading from HuggingFace once the model is uploaded.
    client = AutoPeftModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True)
    tokenizer_dir = client.peft_config['default'].base_model_name_or_path
    if verbose:
        print(tokenizer_dir)
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=True)

    # Fallback path, kept for reference: load a plain causal LM from the
    # local cache, downloading it first if necessary.
    # try:
    #     tokenizer = AutoTokenizer.from_pretrained(
    #         model_name, trust_remote_code=True, local_files_only=True)
    #     client = AutoModelForCausalLM.from_pretrained(
    #         model_name, trust_remote_code=True, local_files_only=True)
    # except Exception:
    #     if pretrained_model_download(model_name, verbose=verbose):
    #         tokenizer = AutoTokenizer.from_pretrained(
    #             model_name, trust_remote_code=True, local_files_only=True)
    #         client = AutoModelForCausalLM.from_pretrained(
    #             model_name, trust_remote_code=True, local_files_only=True)

    # Move the model to the selected device and switch to eval mode.
    client = client.to(device).eval()
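
# Usage sketch (hypothetical path): after init_client() returns, the
# module-level `client` and `tokenizer` are ready for inference:
#
#   init_client("/path/to/peft-checkpoint", verbose=True)
#   assert client is not None and tokenizer is not None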


def pretrained_model_download(model_name_or_path: str, verbose: bool) -> bool:
    """
        Download a model (model_name_or_path) with huggingface_hub.
        Returns True on success, False on failure.

        Params:
            model_name_or_path (`str`): HuggingFace repo id of the model
            verbose (`bool`): whether to print progress information
        Returns:
            `bool`: whether the download succeeded
    """
    # TODO: uses an HF mirror to speed up downloads; not yet tested on Windows.

    # Optionally enable hf_transfer for faster downloads (off by default).
    # Note: environment variables are strings, so compare against "1".
    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") == "1":
        try:
            import hf_transfer
        except ImportError:
            print("Install hf_transfer.")
            os.system("pip -q install hf_transfer")
            import hf_transfer

    # Make sure huggingface_hub is importable.
    try:
        import huggingface_hub
    except ImportError:
        print("Install huggingface_hub.")
        os.system("pip -q install huggingface_hub")
        import huggingface_hub

    # Download the model snapshot through the mirror endpoint.
    try:
        if verbose:
            print(f"downloading {model_name_or_path}")
        huggingface_hub.snapshot_download(
            repo_id=model_name_or_path, endpoint=END_POINT,
            resume_download=True, local_dir_use_symlinks=False)
    except Exception as e:
        # The docstring promises False on failure instead of raising.
        print(f"download failed: {e}")
        return False

    return True
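
# Usage sketch: fetch a public repo through the mirror endpoint; the repo id
# here is only an example and the call downloads real files.
#
#   if pretrained_model_download("THUDM/chatglm3-6b", verbose=True):
#       print("model cached locally")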


def message2query(messages: List[Dict[str, str]]) -> str:
    # [{'role': 'user', 'content': 'Teacher: Please introduce yourself'}]
    # <|system|>
    # You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
    # <|user|>
    # Hello
    # <|assistant|>
    # Hello, I'm ChatGLM3. What can I assist you today?
    template = Template("<|$role|>\n$content\n")

    return "".join([template.substitute(message) for message in messages])
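
# For example, message2query([{'role': 'user', 'content': 'Hello'}]) returns
# '<|user|>\nHello\n', matching the ChatGLM3 chat template sketched above.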


def get_response(message: List[Dict[str, str]], model_name: str = "/workspace/jyh/Zero-Haruhi/checkpoint-1500", verbose: bool = True):
    global client
    global tokenizer

    if client is None:
        init_client(model_name, verbose=verbose)

    if verbose:
        print(message)
        print(message2query(message))

    # chat() is provided by the checkpoint's remote code (ChatGLM3-style
    # API) and returns the reply plus the updated history.
    response, history = client.chat(tokenizer, message2query(message))
    if verbose:
        print((response, history))

    return response
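

if __name__ == "__main__":
    # Minimal smoke test, assuming the default checkpoint path in
    # get_response() exists on this machine; the message is an arbitrary example.
    demo_messages = [{"role": "user", "content": "Hello"}]
    print(get_response(demo_messages))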