kaixuan42 committed on
Commit
2485df2
1 Parent(s): 8e69c27

Upload web_demo.py

Browse files
Files changed (1) hide show
  1. web_demo.py +72 -0
web_demo.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import streamlit as st
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers.generation.utils import GenerationConfig
6
+
7
+
8
+ st.set_page_config(page_title="Baichuan-13B-Chat")
9
+ st.title("Baichuan-13B-Chat")
10
+
11
+
12
@st.cache_resource
def init_model():
    """Load the Baichuan-13B-Chat model and its tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit loads the weights
    only once and reuses them across script reruns.

    Returns:
        A ``(model, tokenizer)`` tuple ready for ``model.chat(...)``.
    """
    repo_id = "baichuan-inc/Baichuan-13B-Chat"
    chat_model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Attach the repo's generation defaults (sampling params etc.).
    chat_model.generation_config = GenerationConfig.from_pretrained(repo_id)
    chat_tokenizer = AutoTokenizer.from_pretrained(
        repo_id,
        use_fast=False,
        trust_remote_code=True,
    )
    return chat_model, chat_tokenizer
29
+
30
+
31
def clear_chat_history():
    """Button callback: discard the stored conversation.

    Uses ``pop`` with a default so the call is a harmless no-op if the
    "messages" key does not exist yet, whereas the bare ``del`` it
    replaces would raise on a missing key.
    """
    st.session_state.pop("messages", None)
33
+
34
+
35
def init_chat_history():
    """Render the greeting plus any prior turns; return the message list.

    Guarantees ``st.session_state.messages`` exists (created empty on the
    first run) and replays each stored turn with a role-specific avatar.
    """
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("您好,我是百川大模型,很高兴为您服务🥰")

    if "messages" not in st.session_state:
        # First run of this session: start with an empty history.
        st.session_state.messages = []
    else:
        # Replay every stored turn so the page survives Streamlit reruns.
        for past in st.session_state.messages:
            role = past["role"]
            with st.chat_message(role, avatar='🧑‍💻' if role == "user" else '🤖'):
                st.markdown(past["content"])

    return st.session_state.messages
48
+
49
+
50
def main():
    """Streamlit entry point: handle one chat turn per script rerun."""
    model, tokenizer = init_model()
    messages = init_chat_history()

    prompt = st.chat_input("Shift + Enter 换行, Enter 发送")
    if prompt:
        # Echo the user's turn and record it before generating a reply.
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        with st.chat_message("assistant", avatar='🤖'):
            slot = st.empty()
            # Stream partial generations into the placeholder as they arrive.
            for response in model.chat(tokenizer, messages, stream=True):
                slot.markdown(response)
                # Free cached memory on Apple-silicon (MPS) backends.
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()

        # `response` holds the final streamed text after the loop ends.
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

    st.button("清空对话", on_click=clear_chat_history)
69
+
70
+
71
# Script entry point (launch with: streamlit run web_demo.py).
if __name__ == "__main__":
    main()