File size: 4,296 Bytes
2485df2
 
4fb7cc8
2485df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef12332
 
 
2485df2
ef12332
 
 
2485df2
15aeb12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2485df2
15aeb12
 
2485df2
15aeb12
2485df2
 
15aeb12
2485df2
15aeb12
 
2485df2
15aeb12
2485df2
 
 
 
ef12332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2485df2
 
 
 
ef12332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2485df2
 
ef12332
 
 
2485df2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")


@st.cache_resource
def init_model():
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer


# def clear_chat_history():
#     del st.session_state.messages
# 清空聊天历史
def clear_chat_history():
    # 把会话状态中的消息记录清空,并刷新页面
    st.session_state.messages = []
    st.experimental_rerun()

# def init_chat_history():
#     with st.chat_message("assistant", avatar='🤖'):
#         st.markdown("您好,我是百川大模型,很高兴为您服务🥰")

#     if "messages" in st.session_state:
#         for message in st.session_state.messages:
#             avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
#             with st.chat_message(message["role"], avatar=avatar):
#                 st.markdown(message["content"])
#     else:
#         st.session_state.messages = []

#     return st.session_state.messages
    
# 初始化聊天历史
def init_chat_history():
    # 显示机器人的欢迎语
    st.write("🤖: 您好,我是百川大模型,很高兴为您服务🥰")

    # 如果会话状态中有消息记录,就显示出来
    if "messages" in st.session_state:
        for message in st.session_state.messages:
            # 根据角色显示不同的头像
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            # 用st.write或st.markdown显示消息内容
            st.write(f"{avatar}: {message['content']}")
    else:
        # 如果没有消息记录,就初始化一个空列表
        st.session_state.messages = []

    return st.session_state.messages

# def main():
#     model, tokenizer = init_model()
#     messages = init_chat_history()

#     if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
#         with st.chat_message("user", avatar='🧑‍💻'):
#             st.markdown(prompt)
#         messages.append({"role": "user", "content": prompt})
#         print(f"[user] {prompt}", flush=True)
#         with st.chat_message("assistant", avatar='🤖'):
#             placeholder = st.empty()
#             for response in model.chat(tokenizer, messages, stream=True):
#                 placeholder.markdown(response)
#                 if torch.backends.mps.is_available():
#                     torch.mps.empty_cache()
#         messages.append({"role": "assistant", "content": response})
#         print(json.dumps(messages, ensure_ascii=False), flush=True)

#         st.button("清空对话", on_click=clear_chat_history)
# 主函数
def main():
    model, tokenizer = init_model()
    messages = init_chat_history()

    # 在侧边栏或表单中创建一个输入框和一个提交按钮
    with st.sidebar.form("chat_form"):
        user_input = st.text_input("请输入您的问题")
        submit_button = st.form_submit_button("发送")

    # 如果用户点击了发送按钮,就把用户的输入添加到消息记录中,并显示出来
    if submit_button:
        messages.append({"role": "user", "content": user_input})
        st.write(f"🧑‍💻: {user_input}")
        print(f"[user] {user_input}", flush=True)

        # 调用模型来生成回复,并添加到消息记录中,并显示出来
        reply = model.chat(tokenizer, messages, stream=False)
        messages.append({"role": "assistant", "content": reply})
        st.write(f"🤖: {reply}")
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        # 如果有清空对话的按钮,就绑定清空聊天历史的函数
        if st.button("清空对话"):
            clear_chat_history()

if __name__ == "__main__":
    main()