WinterGYC committed
Commit
d20d718
1 Parent(s): b5fba31
Files changed (3)
  1. README.md +4 -4
  2. app.py +72 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
  title: Baichuan 13B Chat Int8
- emoji: 🐨
- colorFrom: blue
- colorTo: blue
+ emoji:
+ colorFrom: gray
+ colorTo: yellow
  sdk: streamlit
- sdk_version: 1.21.0
+ sdk_version: 1.24.0
  app_file: app.py
  pinned: false
  license: apache-2.0
app.py ADDED
@@ -0,0 +1,72 @@
+ import json
+ import torch
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers.generation.utils import GenerationConfig
+
+
+ st.set_page_config(page_title="Baichuan-13B-Chat")
+ st.title("Baichuan-13B-Chat")
+
+ @st.cache_resource
+ def init_model():
+     model = AutoModelForCausalLM.from_pretrained(
+         "baichuan-inc/Baichuan-13B-Chat",
+         torch_dtype=torch.float16,
+         device_map="auto",
+         trust_remote_code=True
+     )
+     model.generation_config = GenerationConfig.from_pretrained(
+         "baichuan-inc/Baichuan-13B-Chat"
+     )
+     tokenizer = AutoTokenizer.from_pretrained(
+         "baichuan-inc/Baichuan-13B-Chat",
+         use_fast=False,
+         trust_remote_code=True
+     )
+     model = model.quantize(8).cuda()  # quantize to int8 and move to GPU
+     return model, tokenizer
+
+
+ def clear_chat_history():
+     del st.session_state.messages
+
+
+ def init_chat_history():
+     with st.chat_message("assistant", avatar='🤖'):
+         st.markdown("您好,我是百川大模型,很高兴为您服务🥰")  # "Hello, I am the Baichuan model, happy to serve you 🥰"
+
+     if "messages" in st.session_state:
+         for message in st.session_state.messages:
+             avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
+             with st.chat_message(message["role"], avatar=avatar):
+                 st.markdown(message["content"])
+     else:
+         st.session_state.messages = []
+
+     return st.session_state.messages
+
+
+ def main():
+     model, tokenizer = init_model()
+     messages = init_chat_history()
+
+     if prompt := st.chat_input("Shift + Enter 换行, Enter 发送"):  # "Shift + Enter for newline, Enter to send"
+         with st.chat_message("user", avatar='🧑‍💻'):
+             st.markdown(prompt)
+         messages.append({"role": "user", "content": prompt})
+         print(f"[user] {prompt}", flush=True)
+         with st.chat_message("assistant", avatar='🤖'):
+             placeholder = st.empty()
+             for response in model.chat(tokenizer, messages, stream=True):
+                 placeholder.markdown(response)
+                 if torch.backends.mps.is_available():
+                     torch.mps.empty_cache()
+         messages.append({"role": "assistant", "content": response})
+         print(json.dumps(messages, ensure_ascii=False), flush=True)
+
+         st.button("清空对话", on_click=clear_chat_history)  # "Clear conversation"
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ accelerate
+ colorama
+ cpm_kernels
+ sentencepiece
+ transformers_stream_generator