Nyanfa commited on
Commit
4f4cd9b
1 Parent(s): 5b4675a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +49 -23
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import cohere
2
  import streamlit as st
3
  from streamlit.components.v1 import html
 
4
  import re
5
  import urllib.parse
6
 
@@ -26,6 +27,21 @@ if "messages" not in st.session_state:
26
  st.session_state.messages = []
27
 
28
  def get_ai_response(prompt, chat_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with st.chat_message("ai"):
30
  penalty_kwargs = {
31
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
@@ -41,16 +57,17 @@ def get_ai_response(prompt, chat_history):
41
  p=p,
42
  **penalty_kwargs
43
  )
44
-
45
- response = ""
46
  placeholder = st.empty()
47
  for event in stream:
48
  if event.event_type == "text-generation":
49
  content = event.text
50
- response += content
51
- placeholder.markdown(response)
52
-
53
- return response
 
54
 
55
  def display_messages():
56
  for i, message in enumerate(st.session_state.messages):
@@ -95,7 +112,7 @@ def display_messages():
95
 
96
  if "edit_index" in st.session_state and st.session_state.edit_index == i:
97
  with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
98
- new_content = st.text_area("Edit message", value=st.session_state.messages[i]["text"])
99
  col1, col2 = st.columns([1, 1])
100
  with col1:
101
  if st.form_submit_button("Save"):
@@ -106,16 +123,6 @@ def display_messages():
106
  if st.form_submit_button("Cancel"):
107
  del st.session_state.edit_index
108
  st.rerun()
109
-
110
- if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
111
- if len(st.session_state.messages) > 0: # メッセージリストが空でないことを確認
112
- prompt = st.session_state.messages[-1]["text"]
113
- response = get_ai_response(prompt, st.session_state.messages[:-1])
114
- st.session_state.messages.append({"role": "CHATBOT", "text": response})
115
- st.session_state.retry_flag = False
116
- st.rerun()
117
- else:
118
- st.session_state.retry_flag = False # retry_flagをFalseに設定して処理を続行
119
 
120
  # Add sidebar for advanced settings
121
  with st.sidebar:
@@ -130,7 +137,7 @@ with st.sidebar:
130
  log_text += message["text"] + "\n\n"
131
  log_text = log_text.rstrip("\n")
132
 
133
- # PythonでURLエンコード
134
  log_text_escaped = urllib.parse.quote(log_text)
135
 
136
  copy_log_button_html = f"""
@@ -149,7 +156,7 @@ with st.sidebar:
149
 
150
  st.header("Advanced Settings")
151
  model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
152
- preamble = st.text_area("Preamble", height=100)
153
  temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
154
  k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
155
  p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
@@ -190,16 +197,35 @@ with st.sidebar:
190
  else:
191
  st.warning("Please enter a valid API Key.")
192
 
 
 
 
 
 
 
 
 
193
  display_messages()
194
 
 
 
 
 
 
 
 
 
 
 
 
195
  if prompt := st.chat_input("What is up?"):
196
  chat_history = st.session_state.messages.copy()
197
-
198
  with st.chat_message("user"):
199
  st.write(prompt)
200
-
201
- response = get_ai_response(prompt, chat_history)
202
-
203
  st.session_state.messages.append({"role": "USER", "text": prompt})
 
 
 
204
  st.session_state.messages.append({"role": "CHATBOT", "text": response})
205
  st.rerun()
 
1
  import cohere
2
  import streamlit as st
3
  from streamlit.components.v1 import html
4
+ from streamlit_extras.stylable_container import stylable_container
5
  import re
6
  import urllib.parse
7
 
 
27
  st.session_state.messages = []
28
 
29
  def get_ai_response(prompt, chat_history):
30
+ st.session_state.is_streaming = True
31
+
32
+ with stylable_container(
33
+ key="stop_generating",
34
+ css_styles="""
35
+ button {
36
+ position: fixed;
37
+ bottom: 100px;
38
+ left: 50%;
39
+ transform: translateX(-50%);
40
+ }
41
+ """,
42
+ ):
43
+ st.button("Stop generating")
44
+
45
  with st.chat_message("ai"):
46
  penalty_kwargs = {
47
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
 
57
  p=p,
58
  **penalty_kwargs
59
  )
60
+
61
+ st.session_state.response = ""
62
  placeholder = st.empty()
63
  for event in stream:
64
  if event.event_type == "text-generation":
65
  content = event.text
66
+ st.session_state.response += content
67
+ placeholder.markdown(st.session_state.response)
68
+
69
+ st.session_state.is_streaming = False
70
+ return st.session_state.response
71
 
72
  def display_messages():
73
  for i, message in enumerate(st.session_state.messages):
 
112
 
113
  if "edit_index" in st.session_state and st.session_state.edit_index == i:
114
  with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
115
+ new_content = st.text_area("Edit message", height=200, value=st.session_state.messages[i]["text"])
116
  col1, col2 = st.columns([1, 1])
117
  with col1:
118
  if st.form_submit_button("Save"):
 
123
  if st.form_submit_button("Cancel"):
124
  del st.session_state.edit_index
125
  st.rerun()
 
 
 
 
 
 
 
 
 
 
126
 
127
  # Add sidebar for advanced settings
128
  with st.sidebar:
 
137
  log_text += message["text"] + "\n\n"
138
  log_text = log_text.rstrip("\n")
139
 
140
+ # Encode the string to escape
141
  log_text_escaped = urllib.parse.quote(log_text)
142
 
143
  copy_log_button_html = f"""
 
156
 
157
  st.header("Advanced Settings")
158
  model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
159
+ preamble = st.text_area("Preamble", height=200)
160
  temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
161
  k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
162
  p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
 
197
  else:
198
  st.warning("Please enter a valid API Key.")
199
 
200
+ # After Stop generating
201
+ if "is_streaming" in st.session_state and st.session_state.is_streaming:
202
+ st.session_state.messages.append({"role": "CHATBOT", "text": st.session_state.response})
203
+ st.session_state.is_streaming = False
204
+ if "retry_flag" in st.session_state and st.session_state.retry_flag:
205
+ st.session_state.retry_flag = False
206
+ st.rerun()
207
+
208
  display_messages()
209
 
210
+ # After Retry
211
+ if "retry_flag" in st.session_state and st.session_state.retry_flag == True:
212
+ if len(st.session_state.messages) > 0:
213
+ prompt = st.session_state.messages[-1]["text"]
214
+ response = get_ai_response(prompt, st.session_state.messages[:-1])
215
+ st.session_state.messages.append({"role": "CHATBOT", "text": response})
216
+ st.session_state.retry_flag = False
217
+ st.rerun()
218
+ else:
219
+ st.session_state.retry_flag = False
220
+
221
  if prompt := st.chat_input("What is up?"):
222
  chat_history = st.session_state.messages.copy()
223
+
224
  with st.chat_message("user"):
225
  st.write(prompt)
 
 
 
226
  st.session_state.messages.append({"role": "USER", "text": prompt})
227
+
228
+ response = get_ai_response(prompt, chat_history)
229
+
230
  st.session_state.messages.append({"role": "CHATBOT", "text": response})
231
  st.rerun()
requirements.txt CHANGED
@@ -1 +1,2 @@
1
- cohere
 
 
1
+ cohere
2
+ streamlit-extras