ericzzz committed
Commit 0694801
1 Parent(s): fc843d0

Create app.py

Files changed (1)
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
+ import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+
+ TOKEN_LIMIT = 2048
+ TEMPERATURE = 0.7
+ REPETITION_PENALTY = 1.05
+ MAX_NEW_TOKENS = 500
+ MODEL_NAME = "ericzzz/falcon-rw-1b-chat"
+
+
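+ # keep the conversation in session_state so it survives Streamlit's script reruns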
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+
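+ # inference only, so disable autograd globally to save memory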
+ torch.set_grad_enabled(False)
+
+
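+ # cache the tokenizer and model as a shared resource so they load only once per process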
+ @st.cache_resource()
+ def load_model():
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
+     )
+     return tokenizer, model
+
+
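+ # non-streaming variant of the chat function, kept commented out for reference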
+ # def chat_func(tokenizer, model, chat_history):
+ #     input_ids = tokenizer.apply_chat_template(
+ #         chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+ #     ).to(model.device)
+ #     output_tokens = model.generate(
+ #         input_ids,
+ #         do_sample=True,
+ #         temperature=TEMPERATURE,
+ #         repetition_penalty=REPETITION_PENALTY,
+ #         max_new_tokens=MAX_NEW_TOKENS,
+ #     )
+ #     output_text = tokenizer.decode(
+ #         output_tokens[0][len(input_ids[0]) :], skip_special_tokens=True
+ #     )
+ #     return output_text
+
+
+ def chat_func_stream(tokenizer, model, chat_history, streamer):
+     input_ids = tokenizer.apply_chat_template(
+         chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+     ).to(model.device)
+     # check input length; drop the new user message if it exceeds the limit
+     if len(input_ids[0]) > TOKEN_LIMIT:
+         st.warning(
+             f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
+         )
+         st.session_state.chat_history = st.session_state.chat_history[:-1]
+         return
+     model.generate(
+         input_ids,
+         do_sample=True,
+         temperature=TEMPERATURE,
+         repetition_penalty=REPETITION_PENALTY,
+         max_new_tokens=MAX_NEW_TOKENS,
+         streamer=streamer,
+     )
+     return
+
+
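+ # render a single chat message bubble inside the given container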
+ def show_chat_message(container, chat_message):
+     with container:
+         with st.chat_message(chat_message["role"]):
+             st.write(chat_message["content"])
+
+
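+ # streamer object passed to model.generate: transformers calls put() with each new batch of token ids and end() when generation finishes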
+ class ResponseStreamer:
+     def __init__(self, tokenizer, container, chat_history):
+         self.tokenizer = tokenizer
+         self.container = container
+         self.chat_history = chat_history
+
+         self.first_call_to_put = True
+         self.current_response = ""
+         with self.container:
+             self.placeholder = st.empty()  # placeholder to hold the streamed message
+
+     def put(self, new_token):
+         # the first call to put receives the prompt tokens; do not write those
+         if self.first_call_to_put:
+             self.first_call_to_put = False
+             return
+         # decode the new token and accumulate current_response
+         decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
+         self.current_response += decoded
+         # display the streamed message so far
+         show_chat_message(
+             self.placeholder.container(),
+             {"role": "assistant", "content": self.current_response},
+         )
+
+     def end(self):
+         # save the finished assistant message into the chat history
+         self.chat_history.append(
+             {"role": "assistant", "content": self.current_response}
+         )
+         # reset state (not strictly needed, as a new instance is created per turn)
+         self.first_call_to_put = True
+         self.current_response = ""
+
+
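+ # main app flow: load the model, replay the saved history, then handle new input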
+ tokenizer, model = load_model()
+ chat_messages_container = st.container()
+
+ for msg in st.session_state.chat_history:
+     show_chat_message(chat_messages_container, msg)
+
+ user_input = st.chat_input()
+ if user_input:
+     new_user_message = {"role": "user", "content": user_input}
+     st.session_state.chat_history.append(new_user_message)
+     show_chat_message(chat_messages_container, new_user_message)
+
+     # assistant_message = chat_func(tokenizer, model, st.session_state.chat_history)
+     # assistant_message = {"role": "assistant", "content": assistant_message}
+     # st.session_state.chat_history.append(assistant_message)
+     # show_chat_message(chat_messages_container, assistant_message)
+
+     streamer = ResponseStreamer(
+         tokenizer=tokenizer,
+         container=chat_messages_container,
+         chat_history=st.session_state.chat_history,
+     )
+     chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)