File size: 4,296 Bytes
2485df2 4fb7cc8 2485df2 ef12332 2485df2 ef12332 2485df2 15aeb12 2485df2 15aeb12 2485df2 15aeb12 2485df2 15aeb12 2485df2 15aeb12 2485df2 15aeb12 2485df2 ef12332 2485df2 ef12332 2485df2 ef12332 2485df2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import json
import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
# Configure the browser tab title and draw the page heading.
st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")
@st.cache_resource
def init_model(model_path: str = "baichuan-inc/Baichuan-13B-Chat"):
    """Load the chat model, its generation config, and tokenizer.

    Cached by Streamlit (``st.cache_resource``) so the expensive
    download/load runs once per distinct ``model_path``; the default
    value preserves the original hard-coded behavior.

    Args:
        model_path: Hugging Face repo id or local path of the model.

    Returns:
        A ``(model, tokenizer)`` tuple ready for ``model.chat(...)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # half precision to fit the 13B weights
        device_map="auto",          # spread layers across available devices
        trust_remote_code=True,     # Baichuan ships custom modeling code
    )
    # Attach the repo's generation defaults (sampling params etc.).
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,             # slow tokenizer, as in the upstream example
        trust_remote_code=True,
    )
    return model, tokenizer
# def clear_chat_history():
# del st.session_state.messages
# Clear the chat history
def clear_chat_history():
    """Wipe the stored conversation and refresh the page."""
    # Replace the message log with a fresh empty list, then force a
    # rerun so the cleared state shows up in the UI immediately.
    st.session_state["messages"] = []
    st.experimental_rerun()
# def init_chat_history():
# with st.chat_message("assistant", avatar='🤖'):
# st.markdown("您好,我是百川大模型,很高兴为您服务🥰")
# if "messages" in st.session_state:
# for message in st.session_state.messages:
# avatar = '🧑💻' if message["role"] == "user" else '🤖'
# with st.chat_message(message["role"], avatar=avatar):
# st.markdown(message["content"])
# else:
# st.session_state.messages = []
# return st.session_state.messages
# Initialize the chat history
def init_chat_history():
    """Render the greeting plus any stored messages; return the message list.

    Ensures ``st.session_state.messages`` exists (creating an empty list on
    first run) and replays every prior turn with a role-specific avatar.
    """
    # Assistant greeting shown on every rerun.
    st.write("🤖: 您好,我是百川大模型,很高兴为您服务🥰")
    # setdefault covers both branches of the original if/else: an existing
    # log is reused, a missing one becomes [] (and the replay loop is a no-op).
    history = st.session_state.setdefault("messages", [])
    for msg in history:
        icon = '🧑💻' if msg["role"] == "user" else '🤖'
        st.write(f"{icon}: {msg['content']}")
    return history
# def main():
# model, tokenizer = init_model()
# messages = init_chat_history()
# if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
# with st.chat_message("user", avatar='🧑💻'):
# st.markdown(prompt)
# messages.append({"role": "user", "content": prompt})
# print(f"[user] {prompt}", flush=True)
# with st.chat_message("assistant", avatar='🤖'):
# placeholder = st.empty()
# for response in model.chat(tokenizer, messages, stream=True):
# placeholder.markdown(response)
# if torch.backends.mps.is_available():
# torch.mps.empty_cache()
# messages.append({"role": "assistant", "content": response})
# print(json.dumps(messages, ensure_ascii=False), flush=True)
# st.button("清空对话", on_click=clear_chat_history)
# Main entry point
def main():
    """Drive the chat UI: collect input, query the model, render the exchange."""
    model, tokenizer = init_model()
    messages = init_chat_history()

    # Input box plus submit button live in a sidebar form, so the script
    # only reacts when the user explicitly presses the send button.
    with st.sidebar.form("chat_form"):
        user_input = st.text_input("请输入您的问题")
        submit_button = st.form_submit_button("发送")

    # Guard against blank submissions: the original sent an empty prompt
    # to the model whenever the button was pressed with no text entered.
    if submit_button and user_input:
        # Record and echo the user's turn.
        messages.append({"role": "user", "content": user_input})
        st.write(f"🧑💻: {user_input}")
        print(f"[user] {user_input}", flush=True)
        # Non-streaming call: the complete reply is returned at once.
        reply = model.chat(tokenizer, messages, stream=False)
        messages.append({"role": "assistant", "content": reply})
        st.write(f"🤖: {reply}")
        # Dump the full transcript to stdout for server-side logging.
        print(json.dumps(messages, ensure_ascii=False), flush=True)

    # Button that wipes the conversation history.
    if st.button("清空对话"):
        clear_chat_history()


if __name__ == "__main__":
    main()
|