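# Streamlit web demo for Baichuan-13B-Chat: a minimal chat UI that streams replies
# from the model. A typical way to launch it (the file name here is illustrative):
#     streamlit run web_demo.py
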
import json
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")


@st.cache_resource
def init_model():
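    # Load the chat model in half precision; device_map="auto" lets accelerate place
    # the layers on whatever GPU/CPU devices are available. trust_remote_code pulls
    # the Baichuan model class from the Hugging Face repo rather than from transformers.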
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
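    # Use the generation defaults shipped with the model repo (generation_config.json).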
    model.generation_config = GenerationConfig.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat"
    )
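    # use_fast=False: load the repo's own (slow) tokenizer implementation via remote code.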
    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer


def clear_chat_history():
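    # Dropping the key resets the conversation; init_chat_history() recreates it on
    # the next rerun of the script.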
    del st.session_state.messages


# Earlier revision of init_chat_history() built on st.chat_message, kept commented out for reference:
# def init_chat_history():
#     with st.chat_message("assistant", avatar='🤖'):
#         st.markdown("您好,我是百川大模型,很高兴为您服务🥰")

#     if "messages" in st.session_state:
#         for message in st.session_state.messages:
#             avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
#             with st.chat_message(message["role"], avatar=avatar):
#                 st.markdown(message["content"])
#     else:
#         st.session_state.messages = []

#     return st.session_state.messages

# Initialize the chat history shown on every rerun of the script.
def init_chat_history():
    # Show the assistant's greeting ("Hello, I'm the Baichuan model, happy to help you")
    st.write("🤖: 您好,我是百川大模型,很高兴为您服务🥰")

    # If the session state already holds messages, replay them
    if "messages" in st.session_state:
        for message in st.session_state.messages:
            # Pick an avatar based on each message's role
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            # Render the message content with st.write (st.markdown also works)
            st.write(f"{avatar}: {message['content']}")
    else:
        # Otherwise start with an empty message list
        st.session_state.messages = []

    return st.session_state.messages

def main():
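    # init_model() is cached with st.cache_resource, so the weights are loaded once
    # per process and reused across the reruns Streamlit triggers on every interaction.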
    model, tokenizer = init_model()
    messages = init_chat_history()

    if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):  # "Shift + Enter for a new line, Enter to send"
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)
        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
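            # model.chat(..., stream=True) yields progressively longer partial replies;
            # each chunk overwrites the placeholder so the answer appears to stream in.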
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
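                # On Apple Silicon, release cached MPS memory between chunks to
                # limit memory growth during generation.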
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
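        # Log the full conversation to the server console as a single JSON line.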
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        st.button("清空对话", on_click=clear_chat_history)  # "Clear conversation"


if __name__ == "__main__":
    main()