Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,24 +1,23 @@
|
|
1 |
-
import base64
|
2 |
import os
|
|
|
|
|
3 |
|
4 |
import openai
|
|
|
5 |
import streamlit as st
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
st.set_page_config(page_title="ChatGPT", page_icon="🌐")
|
10 |
-
|
11 |
-
MAIN = st.empty()
|
12 |
|
|
|
13 |
|
14 |
-
def create_download_link(val, filename):
|
15 |
-
b64 = base64.b64encode(val) # val looks like b'...'
|
16 |
-
return f'<a href="data:application/octet-stream;base64,{b64.decode()}" download="{filename}.pdf">Download file</a>'
|
17 |
|
|
|
18 |
|
19 |
-
@st.
|
20 |
def init_openai_settings():
|
21 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
22 |
|
23 |
|
24 |
def init_session():
|
@@ -36,6 +35,8 @@ def new_chat(chat_name):
|
|
36 |
"messages": [
|
37 |
{"role": "system", "content": st.session_state["params"]["prompt"]}
|
38 |
],
|
|
|
|
|
39 |
}
|
40 |
return chat_name
|
41 |
|
@@ -43,191 +44,333 @@ def new_chat(chat_name):
|
|
43 |
def switch_chat(chat_name):
|
44 |
if st.session_state.get("current_chat") != chat_name:
|
45 |
st.session_state["current_chat"] = chat_name
|
46 |
-
|
47 |
st.stop()
|
48 |
|
49 |
|
50 |
-
def
|
51 |
if st.session_state.get("current_chat") != chat_name:
|
52 |
st.session_state["current_chat"] = chat_name
|
53 |
-
|
54 |
-
|
55 |
st.stop()
|
56 |
|
57 |
|
58 |
-
def
|
59 |
-
st.
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
st.session_state["params"] = dict()
|
66 |
-
|
67 |
-
st.session_state["params"]["model"] = chat_config_expander.selectbox(
|
68 |
"Please select a model",
|
69 |
-
["gpt-3.5-turbo"], # , "
|
70 |
help="ID of the model to use",
|
71 |
)
|
72 |
-
st.session_state["params"]["temperature"] =
|
73 |
"Temperature",
|
74 |
min_value=0.0,
|
75 |
max_value=2.0,
|
76 |
value=1.2,
|
77 |
step=0.1,
|
78 |
format="%0.2f",
|
79 |
-
help="
|
80 |
)
|
81 |
-
st.session_state["params"]["max_tokens"] =
|
82 |
-
"
|
83 |
value=2000,
|
84 |
step=1,
|
85 |
min_value=100,
|
86 |
max_value=4000,
|
87 |
help="The maximum number of tokens to generate in the completion",
|
88 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
help="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.",
|
94 |
)
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
new_chat_button = chat_name_container.button(
|
98 |
-
label="➕ New Chat"
|
99 |
-
) # , use_container_width=True
|
100 |
-
if new_chat_button:
|
101 |
-
new_chat_name = f"Chat{len(st.session_state['chats'])}"
|
102 |
-
st.session_state["current_chat"] = new_chat_name
|
103 |
-
new_chat(new_chat_name)
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
chat_name_container.button(
|
109 |
-
label='💬 ' + chat_name,
|
110 |
-
on_click=switch_chat2,
|
111 |
-
key=chat_name,
|
112 |
-
args=(chat_name,),
|
113 |
-
type='primary',
|
114 |
-
# use_container_width=True,
|
115 |
-
)
|
116 |
-
else:
|
117 |
-
chat_name_container.button(
|
118 |
-
label='💬 ' + chat_name,
|
119 |
-
on_click=switch_chat2,
|
120 |
-
key=chat_name,
|
121 |
-
args=(chat_name,),
|
122 |
-
# use_container_width=True,
|
123 |
-
)
|
124 |
|
125 |
-
|
126 |
-
switch_chat(new_chat_name)
|
127 |
|
128 |
-
# Download pdf
|
129 |
-
# if st.session_state.get('current_chat'):
|
130 |
-
# chat = st.session_state["chats"][st.session_state['current_chat']]
|
131 |
-
# pdf = FPDF('p', 'mm', 'A4')
|
132 |
-
# pdf.add_page()
|
133 |
-
# pdf.set_font(family='Times', size=16)
|
134 |
-
# # pdf.cell(40, 50, txt='abcd.pdf')
|
135 |
-
#
|
136 |
-
# if chat["answer"]:
|
137 |
-
# for i in range(len(chat["answer"]) - 1, -1, -1):
|
138 |
-
# # message(chat["answer"][i], key=str(i))
|
139 |
-
# # message(chat['question'][i], is_user=True, key=str(i) + '_user')
|
140 |
-
# pdf.cell(40, txt=f"""YOU: {chat["question"][i]}""")
|
141 |
-
# pdf.cell(40, txt=f"""AI: {chat["answer"][i]}""")
|
142 |
-
#
|
143 |
-
# export_pdf.download_button('📤 PDF', data=pdf.output(dest='S').encode('latin-1'), file_name='abcd.pdf')
|
144 |
-
|
145 |
-
|
146 |
-
def init_chat(chat_name):
|
147 |
-
chat = st.session_state["chats"][chat_name]
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
if len(chat['messages']) == 1 and st.session_state["params"]["prompt"]:
|
154 |
chat["messages"][0]['content'] = st.session_state["params"]["prompt"]
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
for i in range(len(chat["question"])):
|
162 |
-
answer_zoom.markdown(f"""😃 **YOU:** {chat["question"][i]}""")
|
163 |
-
if i < len(chat["answer"]):
|
164 |
-
answer_zoom.markdown(f"""🤖 **AI:** {chat["answer"][i]}""")
|
165 |
|
166 |
-
|
167 |
-
col1, col2 = st.columns([10, 1])
|
168 |
-
input_text = col1.text_area("😃 You: ", "Hello, how are you?", key="input", max_chars=2000,
|
169 |
-
label_visibility='collapsed')
|
170 |
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
def init_css():
|
189 |
-
"""try to fixed input field"""
|
190 |
-
st.markdown(
|
191 |
-
"""
|
192 |
-
<style>
|
193 |
-
|
194 |
-
div[data-testid="stVerticalBlock"] > div[style*="flex-direction: column;"] > [data-testid="stVerticalBlock"] > [data-testid="stForm"] {
|
195 |
-
border: 20px groove red;
|
196 |
-
position: fixed;
|
197 |
-
width: 100%;
|
198 |
-
|
199 |
-
flex-direction: column;
|
200 |
-
flex-grow: 5;
|
201 |
-
overflow: auto;
|
202 |
-
}
|
203 |
-
</style>
|
204 |
-
""",
|
205 |
-
unsafe_allow_html=True,
|
206 |
-
)
|
207 |
|
208 |
|
209 |
-
def
|
210 |
-
if st.session_state["params"]["model"]
|
211 |
response = openai.ChatCompletion.create(
|
212 |
model=st.session_state["params"]["model"],
|
213 |
temperature=st.session_state["params"]["temperature"],
|
214 |
messages=messages,
|
|
|
215 |
max_tokens=st.session_state["params"]["max_tokens"],
|
216 |
)
|
217 |
-
answer = response["choices"][0]["message"]["content"]
|
218 |
else:
|
219 |
raise NotImplementedError('Not implemented yet!')
|
220 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
|
223 |
if __name__ == "__main__":
|
224 |
-
print("
|
225 |
init_openai_settings()
|
226 |
-
# init_css()
|
227 |
init_session()
|
228 |
-
|
229 |
if st.session_state.get("current_chat"):
|
230 |
-
|
231 |
-
init_chat((st.session_state["current_chat"]))
|
232 |
if len(st.session_state["chats"]) == 0:
|
233 |
switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}"))
|
|
|
|
|
1 |
import os
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
|
5 |
import openai
|
6 |
+
import pandas as pd
|
7 |
import streamlit as st
|
8 |
|
9 |
+
from streamlit.elements.utils import _shown_default_value_warning
|
|
|
|
|
|
|
|
|
10 |
|
11 |
+
_shown_default_value_warning = True # https://discuss.streamlit.io/t/why-do-default-values-cause-a-session-state-warning/15485/21
|
12 |
|
|
|
|
|
|
|
13 |
|
14 |
+
st.set_page_config(page_title="ChatGPT", page_icon="🌐")
|
15 |
|
16 |
+
@st.cache_resource
|
17 |
def init_openai_settings():
|
18 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
19 |
+
if os.getenv("OPENAI_PROXY"):
|
20 |
+
openai.proxy = os.getenv("OPENAI_PROXY")
|
21 |
|
22 |
|
23 |
def init_session():
|
|
|
35 |
"messages": [
|
36 |
{"role": "system", "content": st.session_state["params"]["prompt"]}
|
37 |
],
|
38 |
+
"is_delete": False,
|
39 |
+
"display_name": chat_name,
|
40 |
}
|
41 |
return chat_name
|
42 |
|
|
|
44 |
def switch_chat(chat_name):
|
45 |
if st.session_state.get("current_chat") != chat_name:
|
46 |
st.session_state["current_chat"] = chat_name
|
47 |
+
render_chat(chat_name)
|
48 |
st.stop()
|
49 |
|
50 |
|
51 |
+
def switch_chat_name(chat_name):
|
52 |
if st.session_state.get("current_chat") != chat_name:
|
53 |
st.session_state["current_chat"] = chat_name
|
54 |
+
render_sidebar()
|
55 |
+
render_chat(chat_name)
|
56 |
st.stop()
|
57 |
|
58 |
|
59 |
+
def delete_chat(chat_name):
|
60 |
+
if chat_name in st.session_state['chats']:
|
61 |
+
st.session_state['chats'][chat_name]['is_delete'] = True
|
62 |
+
|
63 |
+
current_chats = [chat for chat, value in st.session_state['chats'].items() if not value['is_delete']]
|
64 |
+
if len(current_chats) == 0:
|
65 |
+
switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}"))
|
66 |
+
st.stop()
|
67 |
+
|
68 |
+
if st.session_state["current_chat"] == chat_name:
|
69 |
+
del st.session_state["current_chat"]
|
70 |
+
switch_chat_name(current_chats[0])
|
71 |
+
|
72 |
+
|
73 |
+
def edit_chat(chat_name, zone):
|
74 |
+
def edit():
|
75 |
+
if not st.session_state['edited_name']:
|
76 |
+
print('name is empty!')
|
77 |
+
return None
|
78 |
+
|
79 |
+
if (st.session_state['edited_name'] != chat_name
|
80 |
+
and st.session_state['edited_name'] in st.session_state['chats']):
|
81 |
+
print('name is duplicated!')
|
82 |
+
return None
|
83 |
+
|
84 |
+
if st.session_state['edited_name'] == chat_name:
|
85 |
+
print('name is not modified!')
|
86 |
+
return None
|
87 |
+
|
88 |
+
st.session_state['chats'][chat_name]['display_name'] = st.session_state['edited_name']
|
89 |
+
|
90 |
+
edit_zone = zone.empty()
|
91 |
+
time.sleep(0.1)
|
92 |
+
with edit_zone.container():
|
93 |
+
st.text_input('New Name', st.session_state['chats'][chat_name]['display_name'], key='edited_name')
|
94 |
+
column1, _, column2 = st.columns([1, 5, 1])
|
95 |
+
column1.button('✅', on_click=edit)
|
96 |
+
column2.button('❌')
|
97 |
+
|
98 |
+
|
99 |
+
def render_sidebar_chat_management(zone):
|
100 |
+
new_chat_button = zone.button(label="➕ New Chat", use_container_width=True)
|
101 |
+
if new_chat_button:
|
102 |
+
new_chat_name = f"Chat{len(st.session_state['chats'])}"
|
103 |
+
st.session_state["current_chat"] = new_chat_name
|
104 |
+
new_chat(new_chat_name)
|
105 |
|
106 |
+
with st.sidebar.container():
|
107 |
+
for chat_name in st.session_state["chats"].keys():
|
108 |
+
if st.session_state['chats'][chat_name]['is_delete']:
|
109 |
+
continue
|
110 |
+
if chat_name == st.session_state.get('current_chat'):
|
111 |
+
column1, column2, column3 = zone.columns([7, 1, 1])
|
112 |
+
column1.button(
|
113 |
+
label='💬 ' + st.session_state['chats'][chat_name]['display_name'],
|
114 |
+
on_click=switch_chat_name,
|
115 |
+
key=chat_name,
|
116 |
+
args=(chat_name,),
|
117 |
+
type='primary',
|
118 |
+
use_container_width=True,
|
119 |
+
)
|
120 |
+
column2.button(label='📝', key='edit', on_click=edit_chat, args=(chat_name, zone))
|
121 |
+
column3.button(label='🗑️', key='remove', on_click=delete_chat, args=(chat_name,))
|
122 |
+
else:
|
123 |
+
zone.button(
|
124 |
+
label='💬 ' + st.session_state['chats'][chat_name]['display_name'],
|
125 |
+
on_click=switch_chat_name,
|
126 |
+
key=chat_name,
|
127 |
+
args=(chat_name,),
|
128 |
+
use_container_width=True,
|
129 |
+
)
|
130 |
+
|
131 |
+
if new_chat_button:
|
132 |
+
switch_chat(new_chat_name)
|
133 |
+
|
134 |
+
|
135 |
+
def render_sidebar_gpt_config_tab(zone):
|
136 |
st.session_state["params"] = dict()
|
137 |
+
st.session_state["params"]["model"] = zone.selectbox(
|
|
|
138 |
"Please select a model",
|
139 |
+
["gpt-3.5-turbo"], # , "gpt-4"
|
140 |
help="ID of the model to use",
|
141 |
)
|
142 |
+
st.session_state["params"]["temperature"] = zone.slider(
|
143 |
"Temperature",
|
144 |
min_value=0.0,
|
145 |
max_value=2.0,
|
146 |
value=1.2,
|
147 |
step=0.1,
|
148 |
format="%0.2f",
|
149 |
+
help="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.",
|
150 |
)
|
151 |
+
st.session_state["params"]["max_tokens"] = zone.slider(
|
152 |
+
"Max Tokens",
|
153 |
value=2000,
|
154 |
step=1,
|
155 |
min_value=100,
|
156 |
max_value=4000,
|
157 |
help="The maximum number of tokens to generate in the completion",
|
158 |
)
|
159 |
+
st.session_state["params"]["stream"] = zone.checkbox(
|
160 |
+
"Steaming output",
|
161 |
+
value=True,
|
162 |
+
help="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message",
|
163 |
+
)
|
164 |
+
zone.caption('Looking for help at https://platform.openai.com/docs/api-reference/chat')
|
165 |
+
|
166 |
|
167 |
+
def render_sidebar_prompt_config_tab(zone):
|
168 |
+
prompt_text = zone.empty()
|
169 |
+
st.session_state["params"]["prompt"] = prompt_text.text_area(
|
170 |
+
"System Prompt",
|
171 |
+
"You are a helpful assistant that translates answer from English to Chinese.",
|
172 |
help="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.",
|
173 |
)
|
174 |
+
template = zone.selectbox('Loading From Prompt Template', load_prompt_templates())
|
175 |
+
if template:
|
176 |
+
prompts_df = load_prompts(template)
|
177 |
+
actor = zone.selectbox('Loading Prompts', prompts_df.index.tolist())
|
178 |
+
if actor:
|
179 |
+
st.session_state["params"]["prompt"] = prompt_text.text_area(
|
180 |
+
"System Prompt",
|
181 |
+
prompts_df.loc[actor].prompt,
|
182 |
+
help="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.",
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
def render_download_zone(zone):
|
187 |
+
from io import BytesIO, StringIO
|
188 |
+
if not st.session_state.get('current_chat'):
|
189 |
+
return
|
190 |
+
|
191 |
+
chat = st.session_state['chats'][st.session_state['current_chat']]
|
192 |
+
col1, col2 = zone.columns([1, 1])
|
193 |
+
|
194 |
+
chat_messages = ['# ' + chat['display_name']]
|
195 |
+
if chat["question"]:
|
196 |
+
for i in range(len(chat["question"])):
|
197 |
+
chat_messages.append(f"""😃 **YOU:** {chat["question"][i]}""")
|
198 |
+
if i < len(chat["answer"]):
|
199 |
+
chat_messages.append(f"""🤖 **AI:** {chat["answer"][i]}""")
|
200 |
+
col1.download_button('📤 Markdown', '\n'.join(chat_messages).encode('utf-8'), file_name=f"{chat['display_name']}.md", help="Download messages to a markdown file", use_container_width=True)
|
201 |
+
|
202 |
+
tables = []
|
203 |
+
for answer in chat["answer"]:
|
204 |
+
filter_table_str = '\n'.join([m.strip() for m in answer.split('\n') if m.strip().startswith('| ') or m == ''])
|
205 |
+
try:
|
206 |
+
tables.extend([pd.read_table(StringIO(filter_table_str.replace(' ', '')), sep='|').dropna(axis=1, how='all').iloc[1:] for m in filter_table_str.split('\n\n')])
|
207 |
+
except Exception as e:
|
208 |
+
print(e)
|
209 |
+
if tables:
|
210 |
+
buffer = BytesIO()
|
211 |
+
with pd.ExcelWriter(buffer) as writer:
|
212 |
+
for index, table in enumerate(tables):
|
213 |
+
table.to_excel(writer, sheet_name=str(index + 1), index=False)
|
214 |
+
col2.download_button('📉 Excel', buffer.getvalue(), file_name=f"{chat['display_name']}.xlsx", help="Download tables to a excel file", use_container_width=True)
|
215 |
+
|
216 |
+
|
217 |
+
def render_sidebar():
|
218 |
+
chat_name_container = st.sidebar.container()
|
219 |
+
chat_config_expander = st.sidebar.expander('Chat configuration', True)
|
220 |
+
tab_gpt, tab_prompt = chat_config_expander.tabs(['ChatGPT', 'Prompt'])
|
221 |
+
download_zone = st.sidebar.empty()
|
222 |
+
|
223 |
+
render_sidebar_gpt_config_tab(tab_gpt)
|
224 |
+
render_sidebar_prompt_config_tab(tab_prompt)
|
225 |
+
render_sidebar_chat_management(chat_name_container)
|
226 |
+
render_download_zone(download_zone)
|
227 |
+
|
228 |
+
|
229 |
+
def render_user_message(message, zone):
|
230 |
+
col1, col2 = zone.columns([1,8])
|
231 |
+
col1.markdown("😃 **YOU:**")
|
232 |
+
col2.markdown(message)
|
233 |
+
|
234 |
+
|
235 |
+
def render_ai_message(message, zone):
|
236 |
+
col1, col2 = zone.columns([1,8])
|
237 |
+
col1.markdown("🤖 **AI:**")
|
238 |
+
col2.markdown(message)
|
239 |
+
|
240 |
+
|
241 |
+
def render_history_answer(chat, zone):
|
242 |
+
zone.empty()
|
243 |
+
time.sleep(0.1) # https://github.com/streamlit/streamlit/issues/5044
|
244 |
+
with zone.container():
|
245 |
+
if chat['messages']:
|
246 |
+
st.caption(f"""ℹ️ Prompt: {chat["messages"][0]['content']}""")
|
247 |
+
if chat["question"]:
|
248 |
+
for i in range(len(chat["question"])):
|
249 |
+
render_user_message(chat["question"][i], st)
|
250 |
+
if i < len(chat["answer"]):
|
251 |
+
render_ai_message(chat["answer"][i], st)
|
252 |
+
|
253 |
+
|
254 |
+
def render_last_answer(question, chat, zone):
|
255 |
+
answer_zone = zone.empty()
|
256 |
+
|
257 |
+
chat["messages"].append({"role": "user", "content": question})
|
258 |
+
chat["question"].append(question)
|
259 |
+
if st.session_state["params"]["stream"]:
|
260 |
+
answer = ""
|
261 |
+
chat["answer"].append(answer)
|
262 |
+
chat["messages"].append({"role": "assistant", "content": answer})
|
263 |
+
for response in get_openai_response(chat["messages"]):
|
264 |
+
answer += response["choices"][0]['delta'].get("content", '')
|
265 |
+
chat["answer"][-1] = answer
|
266 |
+
chat["messages"][-1] = {"role": "assistant", "content": answer}
|
267 |
+
render_ai_message(answer, answer_zone)
|
268 |
+
else:
|
269 |
+
with st.spinner("Wait for responding..."):
|
270 |
+
answer = get_openai_response(chat["messages"])["choices"][0]["message"]["content"]
|
271 |
+
chat["answer"].append(answer)
|
272 |
+
chat["messages"].append({"role": "assistant", "content": answer})
|
273 |
+
render_ai_message(answer, zone)
|
274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
+
def render_stop_generate_button(zone):
|
277 |
+
def stop():
|
278 |
+
st.session_state['regenerate'] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
|
280 |
+
zone.columns((2, 1, 2))[1].button('□ Stop', on_click=stop)
|
|
|
281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
|
283 |
+
def render_regenerate_button(chat, zone):
|
284 |
+
def regenerate():
|
285 |
+
chat["messages"].pop(-1)
|
286 |
+
chat["messages"].pop(-1)
|
287 |
+
chat["answer"].pop(-1)
|
288 |
+
st.session_state['regenerate'] = True
|
289 |
+
st.session_state['last_question'] = chat["question"].pop(-1)
|
290 |
+
|
291 |
+
zone.columns((2, 1, 2))[1].button('🔄Regenerate', type='primary', on_click=regenerate)
|
292 |
|
293 |
+
|
294 |
+
def render_chat(chat_name):
|
295 |
+
def handle_ask():
|
296 |
+
if st.session_state['input']:
|
297 |
+
re_generate_zone.empty()
|
298 |
+
render_user_message(st.session_state['input'], last_question_zone)
|
299 |
+
render_stop_generate_button(stop_generate_zone)
|
300 |
+
render_last_answer(st.session_state['input'], chat, last_answer_zone)
|
301 |
+
st.session_state['input'] = ''
|
302 |
+
|
303 |
+
if chat_name not in st.session_state["chats"]:
|
304 |
+
st.error(f'{chat_name} is not exist')
|
305 |
+
return
|
306 |
+
chat = st.session_state["chats"][chat_name]
|
307 |
+
if chat['is_delete']:
|
308 |
+
st.error(f"{chat_name} is deleted")
|
309 |
+
st.stop()
|
310 |
if len(chat['messages']) == 1 and st.session_state["params"]["prompt"]:
|
311 |
chat["messages"][0]['content'] = st.session_state["params"]["prompt"]
|
312 |
|
313 |
+
conversation_zone = st.container()
|
314 |
+
history_zone = conversation_zone.empty()
|
315 |
+
last_question_zone = conversation_zone.empty()
|
316 |
+
last_answer_zone = conversation_zone.empty()
|
317 |
+
ask_form_zone = st.empty()
|
|
|
|
|
|
|
|
|
318 |
|
319 |
+
render_history_answer(chat, history_zone)
|
|
|
|
|
|
|
320 |
|
321 |
+
ask_form = ask_form_zone.form(chat_name)
|
322 |
+
col1, col2 = ask_form.columns([10, 1])
|
323 |
+
col1.text_area("😃 You: ", "Hello, how are you?",
|
324 |
+
key="input",
|
325 |
+
max_chars=2000,
|
326 |
+
label_visibility='collapsed')
|
327 |
|
328 |
+
col2.form_submit_button("🚀", on_click=handle_ask)
|
329 |
+
stop_generate_zone = conversation_zone.empty()
|
330 |
+
re_generate_zone = conversation_zone.empty()
|
331 |
|
332 |
+
if st.session_state.get('regenerate'):
|
333 |
+
render_user_message(st.session_state['last_question'], last_question_zone)
|
334 |
+
render_stop_generate_button(stop_generate_zone)
|
335 |
+
render_last_answer(st.session_state['last_question'], chat, last_answer_zone)
|
336 |
+
st.session_state['regenerate'] = False
|
337 |
+
|
338 |
+
if chat["answer"]:
|
339 |
+
stop_generate_zone.empty()
|
340 |
+
render_regenerate_button(chat, re_generate_zone)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
|
343 |
+
def get_openai_response(messages):
|
344 |
+
if st.session_state["params"]["model"] in {'gpt-3.5-turbo', 'gpt4'}:
|
345 |
response = openai.ChatCompletion.create(
|
346 |
model=st.session_state["params"]["model"],
|
347 |
temperature=st.session_state["params"]["temperature"],
|
348 |
messages=messages,
|
349 |
+
stream=st.session_state["params"]["stream"],
|
350 |
max_tokens=st.session_state["params"]["max_tokens"],
|
351 |
)
|
|
|
352 |
else:
|
353 |
raise NotImplementedError('Not implemented yet!')
|
354 |
+
return response
|
355 |
+
|
356 |
+
|
357 |
+
def load_prompt_templates():
|
358 |
+
path = Path(__file__).parent / "templates"
|
359 |
+
return [f.name for f in path.glob("*.json")]
|
360 |
+
|
361 |
+
|
362 |
+
def load_prompts(template_name):
|
363 |
+
if template_name:
|
364 |
+
path = Path(__file__).parent / "templates" / template_name
|
365 |
+
return pd.read_json(path).drop_duplicates(subset='act').set_index('act') # act, prompt
|
366 |
|
367 |
|
368 |
if __name__ == "__main__":
|
369 |
+
print("---- page reloading ----")
|
370 |
init_openai_settings()
|
|
|
371 |
init_session()
|
372 |
+
render_sidebar()
|
373 |
if st.session_state.get("current_chat"):
|
374 |
+
render_chat(st.session_state["current_chat"])
|
|
|
375 |
if len(st.session_state["chats"]) == 0:
|
376 |
switch_chat(new_chat(f"Chat{len(st.session_state['chats'])}"))
|