"""Streamlit chat UI for the Baichuan-13B-Chat model."""
import json

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Hugging Face Hub id used for the model weights, its generation config and
# the tokenizer (kept in one place so the three loads cannot drift apart).
MODEL_PATH = "baichuan-inc/Baichuan-13B-Chat"

st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")


@st.cache_resource
def init_model():
    """Load the chat model and its tokenizer.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of on every interaction.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float16,
        device_map="auto",
        # Executes modeling code from the model repo; ``model.chat`` used in
        # main() presumably comes from that code -- confirm if upgrading.
        trust_remote_code=True,
    )
    model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer


def clear_chat_history():
    """Drop the stored conversation; used as the clear-button callback.

    ``pop`` with a default tolerates the key already being absent, where
    ``del`` would raise.
    """
    st.session_state.pop("messages", None)


def init_chat_history():
    """Render the greeting plus any stored conversation, and return it.

    Streamlit re-executes the whole script on every interaction, so the
    history kept in ``st.session_state`` is redrawn on each run.

    Returns:
        The ``st.session_state.messages`` list (created empty on the first
        run). Appending to it persists new turns across reruns.
    """
    # Bot greeting, always shown first.
    st.write("🤖: 您好,我是百川大模型,很高兴为您服务🥰")

    if "messages" in st.session_state:
        for message in st.session_state.messages:
            # Choose the avatar by role: user vs. assistant.
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            st.write(f"{avatar}: {message['content']}")
    else:
        # First run in this session: start with an empty history.
        st.session_state.messages = []

    return st.session_state.messages


def main():
    """Run one Streamlit pass: show history and answer a new prompt, if any."""
    model, tokenizer = init_model()
    messages = init_chat_history()

    if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
            # Bound up front so an empty stream cannot leave `response`
            # undefined (previously a NameError at the append below).
            response = ""
            # Each yielded value replaces the placeholder contents and the
            # last one is stored as the reply -- presumably
            # model.chat(stream=True) yields the cumulative text so far.
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
                # Release cached MPS (Apple GPU) memory while streaming.
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        st.button("清空对话", on_click=clear_chat_history)


if __name__ == "__main__":
    main()