import json

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")


@st.cache_resource
def init_model():
    # Load the chat model once and cache it across Streamlit reruns.
    model = AutoModelForCausalLM.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Use the generation settings shipped with the checkpoint.
    model.generation_config = GenerationConfig.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "baichuan-inc/Baichuan-13B-Chat",
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer
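

# Memory-saving sketch (an assumption: the repository's remote code exposes a
# quantize() method, as described in the Baichuan-13B model card). Load in
# float16 without device_map, then quantize to int8 and move to the GPU:
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "baichuan-inc/Baichuan-13B-Chat",
#         torch_dtype=torch.float16,
#         trust_remote_code=True,
#     )
#     model = model.quantize(8).cuda()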


def clear_chat_history():
    # Drop the stored conversation and redraw the page.
    st.session_state.messages = []
    st.experimental_rerun()


def init_chat_history():
    # Greeting shown on every run: "Hello, I am the Baichuan model, happy to serve you."
    st.write("🤖: 您好,我是百川大模型,很高兴为您服务🥰")

    if "messages" in st.session_state:
        # Replay the stored conversation so it survives Streamlit reruns.
        for message in st.session_state.messages:
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            st.write(f"{avatar}: {message['content']}")
    else:
        st.session_state.messages = []

    return st.session_state.messages
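

# Rendering sketch with Streamlit's chat elements (assumes Streamlit >= 1.24,
# which provides st.chat_message); each stored turn in the loop above could be
# shown as:
#
#     with st.chat_message(message["role"], avatar=avatar):
#         st.markdown(message["content"])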


def main():
    model, tokenizer = init_model()
    messages = init_chat_history()

    # Input form lives in the sidebar; the conversation renders in the main area.
    with st.sidebar.form("chat_form"):
        user_input = st.text_input("请输入您的问题")  # "Please enter your question"
        submit_button = st.form_submit_button("发送")  # "Send"

    if submit_button and user_input:
        # Record and echo the user turn.
        messages.append({"role": "user", "content": user_input})
        st.write(f"🧑‍💻: {user_input}")
        print(f"[user] {user_input}", flush=True)

        # Generate a reply over the full conversation and record it.
        reply = model.chat(tokenizer, messages, stream=False)
        messages.append({"role": "assistant", "content": reply})
        st.write(f"🤖: {reply}")
        print(json.dumps(messages, ensure_ascii=False), flush=True)
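
        # Streaming variant (a sketch; it assumes the checkpoint's remote chat()
        # code accepts stream=True and yields progressively longer responses):
        #
        #     placeholder = st.empty()
        #     for partial in model.chat(tokenizer, messages, stream=True):
        #         placeholder.write(f"🤖: {partial}")
        #     reply = partial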

    # "Clear conversation" button resets the stored history.
    if st.button("清空对话"):
        clear_chat_history()


if __name__ == "__main__":
    main()