Willder commited on
Commit
c90ffff
β€’
1 Parent(s): b09af9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +217 -39
app.py CHANGED
@@ -1,57 +1,235 @@
 
1
  import os
2
 
3
  import openai
4
  import streamlit as st
5
- from streamlit_chat import message
6
 
7
- st.set_page_config(
8
- page_title="ChatGPT",
9
- page_icon=":robot:"
10
- )
11
 
12
- st.header("ChatGPT")
13
 
14
- if 'generated' not in st.session_state:
15
- st.session_state['generated'] = []
16
- if 'past' not in st.session_state:
17
- st.session_state['past'] = []
18
- if 'messages' not in st.session_state:
19
- st.session_state['messages'] = [
20
- {"role": "system", "content": "You are a helpful assistant that translates English to Chinese."}]
21
 
22
 
23
- def query(question):
24
- st.session_state['messages'].append({"role": "user", "content": question})
 
25
 
26
- openai.api_key = os.environ['OPENAI_API_KEY']
27
- response = openai.ChatCompletion.create(
28
- model="gpt-3.5-turbo",
29
- temperature=1.2,
30
- messages=st.session_state['messages'],
31
- max_tokens=2000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- answer = response['choices'][0]['message']['content']
35
- st.session_state['messages'].append({"role": "assistant", "content": answer})
36
- return answer
 
 
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- with st.form("my_form"):
40
- # Create a text input for the first field
41
- input_text = st.text_input("You: ", "Hello, how are you?", key="input")
42
 
43
- # Every form must have a submit button.
44
- # c1, c2 = st.columns([2, 2])
45
- submitted = st.form_submit_button("πŸ€– Submit")
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- if submitted and input_text:
48
- output = query(input_text, )
49
- if output:
50
- st.session_state.past.append(input_text)
51
- st.session_state.generated.append(output)
52
 
53
- if st.session_state['generated']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- for i in range(len(st.session_state['generated']) - 1, -1, -1):
56
- message(st.session_state["generated"][i], key=str(i))
57
- message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
 
 
 
 
 
 
 
 
 
1
+ import base64
2
  import os
3
 
4
  import openai
5
  import streamlit as st
 
6
 
7
+ # from fpdf import FPDF
 
 
 
8
 
9
+ st.set_page_config(page_title="ChatGPT", page_icon="🌐")
10
 
11
+ MAIN = st.empty()
 
 
 
 
 
 
12
 
13
 
14
+ def create_download_link(val, filename):
15
+ b64 = base64.b64encode(val) # val looks like b'...'
16
+ return f'<a href="data:application/octet-stream;base64,{b64.decode()}" download="{filename}.pdf">Download file</a>'
17
 
18
+
19
+ @st.cache_resource
20
+ def init_openai_settings():
21
+ openai.api_key = os.getenv("OPENAI_API_KEY")
22
+
23
+
24
+ def init_session():
25
+ if not st.session_state.get("chats"):
26
+ st.session_state["chats"] = {}
27
+
28
+
29
+ def new_chat(chat_name):
30
+ if not st.session_state["chats"].get(chat_name):
31
+ st.session_state["chats"][chat_name] = {
32
+ "answer": [],
33
+ "question": [],
34
+ "messages": [
35
+ {"role": "system", "content": st.session_state["params"]["prompt"]}
36
+ ],
37
+ }
38
+ return chat_name
39
+
40
+
41
+ def switch_chat(chat_name):
42
+ if st.session_state.get("current_chat") != chat_name:
43
+ st.session_state["current_chat"] = chat_name
44
+ init_chat(chat_name)
45
+ st.stop()
46
+
47
+
48
+ def switch_chat2(chat_name):
49
+ if st.session_state.get("current_chat") != chat_name:
50
+ st.session_state["current_chat"] = chat_name
51
+ init_sidebar()
52
+ init_chat(chat_name)
53
+ st.stop()
54
+
55
+
56
+ def init_sidebar():
57
+ st.sidebar.title("ChatGPT")
58
+ chat_name_container = st.sidebar.container()
59
+ chat_config_expander = st.sidebar.expander('Chat configuration')
60
+ # export_pdf = st.sidebar.empty()
61
+
62
+ # chat config
63
+ st.session_state["params"] = dict()
64
+ # st.session_state['params']["api_key"] = chat_config_expander.text_input("API_KEY", placeholder="Please input openai key")
65
+ st.session_state["params"]["model"] = chat_config_expander.selectbox(
66
+ "Please select a model",
67
+ ["gpt-3.5-turbo"], # , "text-davinci-003"
68
+ help="ID of the model to use",
69
+ )
70
+ st.session_state["params"]["temperature"] = chat_config_expander.slider(
71
+ "Temperature",
72
+ min_value=0.0,
73
+ max_value=2.0,
74
+ value=1.2,
75
+ step=0.1,
76
+ format="%0.2f",
77
+ help="""What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.""",
78
  )
79
+ st.session_state["params"]["max_tokens"] = chat_config_expander.number_input(
80
+ "MAX_TOKENS",
81
+ value=2000,
82
+ step=1,
83
+ max_value=4000,
84
+ help="The maximum number of tokens to generate in the completion",
85
+ )
86
+ st.session_state["params"]["prompt"] = chat_config_expander.text_area(
87
+ "Prompts",
88
+ "You are a helpful assistant that answer questions as possible as you can.",
89
+ help="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.",
90
+ )
91
+ chat_config_expander.caption('Looking for help at https://platform.openai.com/docs/api-reference/chat')
92
 
93
+ new_chat_button = chat_name_container.button(
94
+ label="βž• New Chat", use_container_width=True
95
+ )
96
+ if new_chat_button:
97
+ new_chat_name = f"Chat{len(st.session_state['chats'])}"
98
+ st.session_state["current_chat"] = new_chat_name
99
+ new_chat(new_chat_name)
100
 
101
+ with st.sidebar.container():
102
+ for chat_name in st.session_state.get("chats", {}).keys():
103
+ if chat_name == st.session_state.get('current_chat'):
104
+ chat_name_container.button(
105
+ label='πŸ’¬ ' + chat_name,
106
+ on_click=switch_chat2,
107
+ key=chat_name,
108
+ args=(chat_name,),
109
+ type='primary',
110
+ use_container_width=True,
111
+ )
112
+ else:
113
+ chat_name_container.button(
114
+ label='πŸ’¬ ' + chat_name,
115
+ on_click=switch_chat2,
116
+ key=chat_name,
117
+ args=(chat_name,),
118
+ use_container_width=True,
119
+ )
120
 
121
+ if new_chat_button:
122
+ switch_chat(new_chat_name)
 
123
 
124
+ # Download pdf
125
+ # if st.session_state.get('current_chat'):
126
+ # chat = st.session_state["chats"][st.session_state['current_chat']]
127
+ # pdf = FPDF('p', 'mm', 'A4')
128
+ # pdf.add_page()
129
+ # pdf.set_font(family='Times', size=16)
130
+ # # pdf.cell(40, 50, txt='abcd.pdf')
131
+ #
132
+ # if chat["answer"]:
133
+ # for i in range(len(chat["answer"]) - 1, -1, -1):
134
+ # # message(chat["answer"][i], key=str(i))
135
+ # # message(chat['question'][i], is_user=True, key=str(i) + '_user')
136
+ # pdf.cell(40, txt=f"""YOU: {chat["question"][i]}""")
137
+ # pdf.cell(40, txt=f"""AI: {chat["answer"][i]}""")
138
+ #
139
+ # export_pdf.download_button('πŸ“€ PDF', data=pdf.output(dest='S').encode('latin-1'), file_name='abcd.pdf')
140
 
 
 
 
 
 
141
 
142
+ def init_chat(chat_name):
143
+ chat = st.session_state["chats"][chat_name]
144
+
145
+ # with MAIN.container():
146
+ answer_zoom = st.container()
147
+ ask_form = st.empty()
148
+
149
+ if len(chat['messages']) == 1 and st.session_state["params"]["prompt"]:
150
+ chat["messages"][0]['content'] = st.session_state["params"]["prompt"]
151
+
152
+ if chat['messages']:
153
+ # answer_zoom.markdown(f"""πŸ€– **Prompt:** {chat["messages"][0]['content']}""")
154
+ answer_zoom.info(f"""Prompt: {chat["messages"][0]['content']}""", icon="ℹ️")
155
+ answer_zoom.caption(f"""ℹ️ Prompt: {chat["messages"][0]['content']}""")
156
+ if chat["question"]:
157
+ for i in range(len(chat["question"])):
158
+ answer_zoom.markdown(f"""πŸ˜ƒ **YOU:** {chat["question"][i]}""")
159
+ if i < len(chat["answer"]):
160
+ answer_zoom.markdown(f"""πŸ€– **AI:** {chat["answer"][i]}""")
161
+
162
+ with ask_form.form(chat_name):
163
+ col1, col2 = st.columns([10, 1])
164
+ question_widget = col1.empty()
165
+ if not chat["question"]:
166
+ input_text = question_widget.text_area("πŸ˜ƒ You: ", "Hello, how are you?", key="input", max_chars=2000,
167
+ label_visibility='collapsed')
168
+ else:
169
+ input_text = question_widget.text_area("πŸ˜ƒ You: ", "", key="input", max_chars=2000,
170
+ label_visibility='collapsed')
171
+
172
+ submitted = col2.form_submit_button("πŸ›«")
173
+
174
+ if submitted and input_text:
175
+ chat["messages"].append({"role": "user", "content": input_text})
176
+ answer_zoom.markdown(f"""πŸ˜ƒ **YOU:** {input_text}""")
177
+
178
+ with st.spinner("Wait for responding..."):
179
+ answer = ask(chat["messages"])
180
+ answer_zoom.markdown(f"""πŸ€– **AI:** {answer}""")
181
+ chat["messages"].append({"role": "assistant", "content": answer})
182
+ if answer:
183
+ chat["question"].append(input_text)
184
+ chat["answer"].append(answer)
185
+
186
+ question_widget.text_area("πŸ˜ƒ You: ", "", key="input-1", max_chars=2000,
187
+ label_visibility='collapsed')
188
+
189
+
190
+ def init_css():
191
+ """try to fixed input field"""
192
+ st.markdown(
193
+ """
194
+ <style>
195
+
196
+ div[data-testid="stVerticalBlock"] > div[style*="flex-direction: column;"] > [data-testid="stVerticalBlock"] > [data-testid="stForm"] {
197
+ border: 20px groove red;
198
+ position: fixed;
199
+ width: 100%;
200
+
201
+ flex-direction: column;
202
+ flex-grow: 5;
203
+ overflow: auto;
204
+ }
205
+ </style>
206
+ """,
207
+ unsafe_allow_html=True,
208
+ )
209
+
210
+
211
+ def ask(messages):
212
+ if st.session_state["params"]["model"] == 'gpt-3.5-turbo':
213
+ response = openai.ChatCompletion.create(
214
+ model=st.session_state["params"]["model"],
215
+ temperature=st.session_state["params"]["temperature"],
216
+ messages=messages,
217
+ max_tokens=st.session_state["params"]["max_tokens"],
218
+ )
219
+ answer = response["choices"][0]["message"]["content"]
220
+ else:
221
+ raise NotImplementedError('Not implemented yet!')
222
+ return answer
223
+
224
 
225
+ if __name__ == "__main__":
226
+ print("loading")
227
+ init_openai_settings()
228
+ # init_css()
229
+ init_session()
230
+ init_sidebar()
231
+ if st.session_state.get("current_chat"):
232
+ print("current_chat: ", st.session_state.get("current_chat"))
233
+ init_chat((st.session_state["current_chat"]))
234
+ if len(st.session_state["chats"]) == 0:
235
+ switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}"))