"""Streamlit web demo: a chat UI backed by the Baichuan-13B-Chat model."""

import json

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Checkpoint identifier shared by the model, generation config and tokenizer.
MODEL_ID = "baichuan-inc/Baichuan-13B-Chat"

st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")


@st.cache_resource
def init_model():
    """Load the chat model and its tokenizer.

    Decorated with ``st.cache_resource`` so the (expensive) load is not
    repeated on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    model.generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=False,
        trust_remote_code=True,
    )
    # NOTE(review): the model is loaded with device_map="auto" and then
    # quantized and moved with .cuda(). Upstream guidance appears to
    # recommend loading on CPU (no device_map) before quantizing, to avoid
    # placing full-precision weights on the GPU first -- confirm before
    # changing; behaviour is intentionally left as-is here.
    model = model.quantize(8).cuda()
    return model, tokenizer


def clear_chat_history() -> None:
    """Drop the stored conversation from the Streamlit session state.

    Safe to call when no conversation has been stored yet.
    """
    if "messages" in st.session_state:
        del st.session_state.messages


def init_chat_history() -> list:
    """Render the greeting plus any stored conversation.

    Creates an empty history in the session state when none exists.

    Returns:
        The list of message dicts (``{"role": ..., "content": ...}``) kept
        in ``st.session_state.messages``.
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("您好,我是百川大模型,很高兴为您服务🥰")

    if "messages" in st.session_state:
        for message in st.session_state.messages:
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            with st.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])
    else:
        st.session_state.messages = []

    return st.session_state.messages


def main() -> None:
    """Run one pass of the chat page: show history, handle a new prompt."""
    model, tokenizer = init_model()
    messages = init_chat_history()

    if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        # Pre-bind so an empty stream cannot leave ``response`` undefined
        # when it is appended to the history below.
        response = ""
        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
            # Each streamed item is rendered into the same placeholder,
            # replacing the previous one.
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        st.button("清空对话", on_click=clear_chat_history)


if __name__ == "__main__":
    main()