import json

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")


@st.cache_resource
def init_model():
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer


# Original version: simply deleted the message history from session state.
# def clear_chat_history():
#     del st.session_state.messages


# Clear the chat history
def clear_chat_history():
    # Empty the message history in session state and refresh the page
    st.session_state.messages = []
    st.experimental_rerun()  # st.rerun() on newer Streamlit versions


# Original version: rendered the history as st.chat_message bubbles with avatars.
# def init_chat_history():
#     with st.chat_message("assistant", avatar='🤖'):
#         st.markdown("您好,我是百川大模型,很高兴为您服务🥰")
#
#     if "messages" in st.session_state:
#         for message in st.session_state.messages:
#             avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
#             with st.chat_message(message["role"], avatar=avatar):
#                 st.markdown(message["content"])
#     else:
#         st.session_state.messages = []
#
#     return st.session_state.messages


# Initialize the chat history
def init_chat_history():
    # Show the bot's greeting
    st.write("🤖: 您好,我是百川大模型,很高兴为您服务🥰")

    # If session state already has a message history, render it
    if "messages" in st.session_state:
        for message in st.session_state.messages:
            # Pick a different avatar depending on the role
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            # Display the message content with st.write (or st.markdown)
            st.write(f"{avatar}: {message['content']}")
    else:
        # No message history yet: initialize an empty list
        st.session_state.messages = []

    return st.session_state.messages


# Original version: used st.chat_input and streamed the reply into a placeholder.
# def main():
#     model, tokenizer = init_model()
#     messages = init_chat_history()
#
#     if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
#         with st.chat_message("user", avatar='🧑‍💻'):
#             st.markdown(prompt)
#         messages.append({"role": "user", "content": prompt})
#         print(f"[user] {prompt}", flush=True)
#         with st.chat_message("assistant", avatar='🤖'):
#             placeholder = st.empty()
#             for response in model.chat(tokenizer, messages, stream=True):
#                 placeholder.markdown(response)
#                 if torch.backends.mps.is_available():
#                     torch.mps.empty_cache()
#         messages.append({"role": "assistant", "content": response})
#         print(json.dumps(messages, ensure_ascii=False), flush=True)
#
#         st.button("清空对话", on_click=clear_chat_history)


# Main function
def main():
    model, tokenizer = init_model()
    messages = init_chat_history()

    # Create an input box and a submit button in a form in the sidebar
    with st.sidebar.form("chat_form"):
        user_input = st.text_input("请输入您的问题")
        submit_button = st.form_submit_button("发送")

    # If the user clicked send (and the input is not empty), append the input
    # to the message history and display it
    if submit_button and user_input:
        messages.append({"role": "user", "content": user_input})
        st.write(f"🧑‍💻: {user_input}")
        print(f"[user] {user_input}", flush=True)

        # Call the model to generate a reply, append it to the message history,
        # and display it
        reply = model.chat(tokenizer, messages, stream=False)
        messages.append({"role": "assistant", "content": reply})
        st.write(f"🤖: {reply}")
        print(json.dumps(messages, ensure_ascii=False), flush=True)

    # When the clear-chat button is clicked, clear the chat history
    if st.button("清空对话"):
        clear_chat_history()


if __name__ == "__main__":
    main()
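

# A rough sketch (not wired into main() above): if you want to keep the
# streamed, incremental display from the original chat_input version while
# using the sidebar-form layout, the reply step could look roughly like this.
# It assumes the remote-code `model.chat(..., stream=True)` generator used in
# the commented-out main() above, which yields the reply text accumulated so
# far on each iteration; `stream_reply` is a name made up here.
#
# def stream_reply(model, tokenizer, messages):
#     placeholder = st.empty()
#     response = ""
#     # Overwriting the same placeholder on every yield makes the reply
#     # appear to grow in place instead of printing once at the end.
#     for response in model.chat(tokenizer, messages, stream=True):
#         placeholder.write(f"🤖: {response}")
#     return response
#
# In main(), `reply = model.chat(tokenizer, messages, stream=False)` could then
# be replaced by `reply = stream_reply(model, tokenizer, messages)`.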