import torch
import gradio as gr
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
#device = "cuda"  # the device to load the model onto
device = "cpu"  # the device to load the model onto


bot_avatar = "shuaikang/dl_logo_rect.png"           # Path to the chatbot avatar image
user_avatar = "shuaikang/user_avatar.jpg"           # Path to the user avatar image
#model_path = "sethuiyer/Medichat-Llama3-8B"   # Path/ID of the already-downloaded model
model_path = "johnsnowlabs/JSL-MedMX-7X"
#model_path = "aaditya/Llama3-OpenBioLLM-8B"

# Global conversation history. Llama3-style models support a system prompt,
# so one is set by default.
SYSTEM_MESSAGE = {
    "role": "system",
    "content": "You are a helpful assistant trained by MetaAI! But you are running with DataLearnerAI Code."
}
llama3_chat_history = [SYSTEM_MESSAGE]

# Globals for the model components, populated by init_model()
tokenizer = None
streamer = None
model = None
terminators = None


def init_model():
    """Load the tokenizer and model from model_path and set up streaming."""
    global tokenizer, model, streamer, terminators
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map=device,
        trust_remote_code=True
    )
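    # Note: torch.float16 targets GPU inference; on CPU it can be slow or
    # unsupported for some ops, so torch.float32 is usually safer when device="cpu".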

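    # Llama3-style models end each assistant turn with <|eot_id|>, so register it
    # as an extra stop token alongside the regular eos_token_id.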
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

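    # TextIteratorStreamer yields generated text as a Python iterator;
    # skip_prompt drops the echoed input and skip_special_tokens hides control tokens.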
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )


with gr.Blocks() as demo:
    # Step 1: load the model
    init_model()

    # Step 2: build the Gradio chatbot UI (chat window, textbox, clear button)
    chatbot = gr.Chatbot(
        height=900,
        avatar_images=(user_avatar, bot_avatar)
    )
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    # Clear the conversation history while keeping the default system prompt
    def clear_history():
        global llama3_chat_history
        llama3_chat_history = [SYSTEM_MESSAGE]

    # Build and stream a reply to the user's message
    def respond(message, chat_history):

        global llama3_chat_history, tokenizer, model, streamer

        llama3_chat_history.append({"role": "user", "content": message})

        # Format the conversation with the model's built-in chat template
        history_str = tokenizer.apply_chat_template(
            llama3_chat_history,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize the formatted prompt and move it to the target device
        inputs = tokenizer(history_str, return_tensors='pt').to(device)

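        # Append a [user, assistant] pair; the empty assistant slot is filled as tokens stream in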
        chat_history.append([message, ""])

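        # Conservative sampling: low temperature plus nucleus sampling (top_p)
        # keeps answers focused while still allowing some variety.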
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=4096,
            num_beams=1,
            do_sample=True,
            top_p=0.8,
            temperature=0.3,
            eos_token_id=terminators
        )

        # Run generation in a background thread so tokens can be streamed as they arrive
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

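        # Consume the streamer on the main thread; each yield pushes the
        # partial reply to the Gradio UI.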
        for new_text in streamer:
            chat_history[-1][1] += new_text
            yield "", chat_history

        llama3_chat_history.append(
            {"role": "assistant", "content": chat_history[-1][1]}
        )

    # Clicking the clear button also resets the global history
    clear.click(clear_history)
    msg.submit(respond, [msg, chatbot], [msg, chatbot])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
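    # Assuming this file is saved as app.py, run `python app.py` and open
    # http://localhost:7860 in a browser (server_name="0.0.0.0" also allows LAN access).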