Spaces:
Running
Running
Upload 2 files
Browse files- app.py +49 -23
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import cohere
|
2 |
import streamlit as st
|
3 |
from streamlit.components.v1 import html
|
|
|
4 |
import re
|
5 |
import urllib.parse
|
6 |
|
@@ -26,6 +27,21 @@ if "messages" not in st.session_state:
|
|
26 |
st.session_state.messages = []
|
27 |
|
28 |
def get_ai_response(prompt, chat_history):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
with st.chat_message("ai"):
|
30 |
penalty_kwargs = {
|
31 |
"frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
|
@@ -41,16 +57,17 @@ def get_ai_response(prompt, chat_history):
|
|
41 |
p=p,
|
42 |
**penalty_kwargs
|
43 |
)
|
44 |
-
|
45 |
-
response = ""
|
46 |
placeholder = st.empty()
|
47 |
for event in stream:
|
48 |
if event.event_type == "text-generation":
|
49 |
content = event.text
|
50 |
-
response += content
|
51 |
-
placeholder.markdown(response)
|
52 |
-
|
53 |
-
|
|
|
54 |
|
55 |
def display_messages():
|
56 |
for i, message in enumerate(st.session_state.messages):
|
@@ -95,7 +112,7 @@ def display_messages():
|
|
95 |
|
96 |
if "edit_index" in st.session_state and st.session_state.edit_index == i:
|
97 |
with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
|
98 |
-
new_content = st.text_area("Edit message", value=st.session_state.messages[i]["text"])
|
99 |
col1, col2 = st.columns([1, 1])
|
100 |
with col1:
|
101 |
if st.form_submit_button("Save"):
|
@@ -106,16 +123,6 @@ def display_messages():
|
|
106 |
if st.form_submit_button("Cancel"):
|
107 |
del st.session_state.edit_index
|
108 |
st.rerun()
|
109 |
-
|
110 |
-
if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
|
111 |
-
if len(st.session_state.messages) > 0: # メッセージリストが空でないことを確認
|
112 |
-
prompt = st.session_state.messages[-1]["text"]
|
113 |
-
response = get_ai_response(prompt, st.session_state.messages[:-1])
|
114 |
-
st.session_state.messages.append({"role": "CHATBOT", "text": response})
|
115 |
-
st.session_state.retry_flag = False
|
116 |
-
st.rerun()
|
117 |
-
else:
|
118 |
-
st.session_state.retry_flag = False # retry_flagをFalseに設定して処理を続行
|
119 |
|
120 |
# Add sidebar for advanced settings
|
121 |
with st.sidebar:
|
@@ -130,7 +137,7 @@ with st.sidebar:
|
|
130 |
log_text += message["text"] + "\n\n"
|
131 |
log_text = log_text.rstrip("\n")
|
132 |
|
133 |
-
#
|
134 |
log_text_escaped = urllib.parse.quote(log_text)
|
135 |
|
136 |
copy_log_button_html = f"""
|
@@ -149,7 +156,7 @@ with st.sidebar:
|
|
149 |
|
150 |
st.header("Advanced Settings")
|
151 |
model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
|
152 |
-
preamble = st.text_area("Preamble", height=
|
153 |
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
|
154 |
k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
|
155 |
p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
|
@@ -190,16 +197,35 @@ with st.sidebar:
|
|
190 |
else:
|
191 |
st.warning("Please enter a valid API Key.")
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
display_messages()
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
if prompt := st.chat_input("What is up?"):
|
196 |
chat_history = st.session_state.messages.copy()
|
197 |
-
|
198 |
with st.chat_message("user"):
|
199 |
st.write(prompt)
|
200 |
-
|
201 |
-
response = get_ai_response(prompt, chat_history)
|
202 |
-
|
203 |
st.session_state.messages.append({"role": "USER", "text": prompt})
|
|
|
|
|
|
|
204 |
st.session_state.messages.append({"role": "CHATBOT", "text": response})
|
205 |
st.rerun()
|
|
|
1 |
import cohere
|
2 |
import streamlit as st
|
3 |
from streamlit.components.v1 import html
|
4 |
+
from streamlit_extras.stylable_container import stylable_container
|
5 |
import re
|
6 |
import urllib.parse
|
7 |
|
|
|
27 |
st.session_state.messages = []
|
28 |
|
29 |
def get_ai_response(prompt, chat_history):
|
30 |
+
st.session_state.is_streaming = True
|
31 |
+
|
32 |
+
with stylable_container(
|
33 |
+
key="stop_generating",
|
34 |
+
css_styles="""
|
35 |
+
button {
|
36 |
+
position: fixed;
|
37 |
+
bottom: 100px;
|
38 |
+
left: 50%;
|
39 |
+
transform: translateX(-50%);
|
40 |
+
}
|
41 |
+
""",
|
42 |
+
):
|
43 |
+
st.button("Stop generating")
|
44 |
+
|
45 |
with st.chat_message("ai"):
|
46 |
penalty_kwargs = {
|
47 |
"frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
|
|
|
57 |
p=p,
|
58 |
**penalty_kwargs
|
59 |
)
|
60 |
+
|
61 |
+
st.session_state.response = ""
|
62 |
placeholder = st.empty()
|
63 |
for event in stream:
|
64 |
if event.event_type == "text-generation":
|
65 |
content = event.text
|
66 |
+
st.session_state.response += content
|
67 |
+
placeholder.markdown(st.session_state.response)
|
68 |
+
|
69 |
+
st.session_state.is_streaming = False
|
70 |
+
return st.session_state.response
|
71 |
|
72 |
def display_messages():
|
73 |
for i, message in enumerate(st.session_state.messages):
|
|
|
112 |
|
113 |
if "edit_index" in st.session_state and st.session_state.edit_index == i:
|
114 |
with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
|
115 |
+
new_content = st.text_area("Edit message", height=200, value=st.session_state.messages[i]["text"])
|
116 |
col1, col2 = st.columns([1, 1])
|
117 |
with col1:
|
118 |
if st.form_submit_button("Save"):
|
|
|
123 |
if st.form_submit_button("Cancel"):
|
124 |
del st.session_state.edit_index
|
125 |
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
# Add sidebar for advanced settings
|
128 |
with st.sidebar:
|
|
|
137 |
log_text += message["text"] + "\n\n"
|
138 |
log_text = log_text.rstrip("\n")
|
139 |
|
140 |
+
# Encode the string to escape
|
141 |
log_text_escaped = urllib.parse.quote(log_text)
|
142 |
|
143 |
copy_log_button_html = f"""
|
|
|
156 |
|
157 |
st.header("Advanced Settings")
|
158 |
model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
|
159 |
+
preamble = st.text_area("Preamble", height=200)
|
160 |
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
|
161 |
k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
|
162 |
p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
|
|
|
197 |
else:
|
198 |
st.warning("Please enter a valid API Key.")
|
199 |
|
200 |
+
# After Stop generating
|
201 |
+
if "is_streaming" in st.session_state and st.session_state.is_streaming:
|
202 |
+
st.session_state.messages.append({"role": "CHATBOT", "text": st.session_state.response})
|
203 |
+
st.session_state.is_streaming = False
|
204 |
+
if "retry_flag" in st.session_state and st.session_state.retry_flag:
|
205 |
+
st.session_state.retry_flag = False
|
206 |
+
st.rerun()
|
207 |
+
|
208 |
display_messages()
|
209 |
|
210 |
+
# After Retry
|
211 |
+
if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
|
212 |
+
if len(st.session_state.messages) > 0:
|
213 |
+
prompt = st.session_state.messages[-1]["text"]
|
214 |
+
response = get_ai_response(prompt, st.session_state.messages[:-1])
|
215 |
+
st.session_state.messages.append({"role": "CHATBOT", "text": response})
|
216 |
+
st.session_state.retry_flag = False
|
217 |
+
st.rerun()
|
218 |
+
else:
|
219 |
+
st.session_state.retry_flag = False
|
220 |
+
|
221 |
if prompt := st.chat_input("What is up?"):
|
222 |
chat_history = st.session_state.messages.copy()
|
223 |
+
|
224 |
with st.chat_message("user"):
|
225 |
st.write(prompt)
|
|
|
|
|
|
|
226 |
st.session_state.messages.append({"role": "USER", "text": prompt})
|
227 |
+
|
228 |
+
response = get_ai_response(prompt, chat_history)
|
229 |
+
|
230 |
st.session_state.messages.append({"role": "CHATBOT", "text": response})
|
231 |
st.rerun()
|
requirements.txt
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
cohere
|
|
|
|
1 |
+
cohere
|
2 |
+
streamlit-extras
|