Nyanfa commited on
Commit
bbd6eac
·
verified ·
1 Parent(s): 0f74d9c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +158 -45
  2. requirements.txt +4 -1
app.py CHANGED
@@ -4,6 +4,9 @@ from streamlit.components.v1 import html
4
  from streamlit_extras.stylable_container import stylable_container
5
  import re
6
  import urllib.parse
 
 
 
7
 
8
  st.title("Cohere Chat UI")
9
 
@@ -12,7 +15,7 @@ if "api_key" not in st.session_state:
12
  if api_key:
13
  if api_key.isascii():
14
  st.session_state.api_key = api_key
15
- client = cohere.Client(api_key=api_key)
16
  st.rerun()
17
  else:
18
  st.warning("Please enter your API key correctly.")
@@ -21,30 +24,84 @@ if "api_key" not in st.session_state:
21
  st.warning("Please enter your API key to use the app. You can obtain your API key from here: https://dashboard.cohere.com/api-keys")
22
  st.stop()
23
  else:
24
- client = cohere.Client(api_key=st.session_state.api_key)
25
 
26
  if "messages" not in st.session_state:
27
  st.session_state.messages = []
28
 
29
- def get_ai_response(prompt, chat_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  st.session_state.is_streaming = True
31
  st.session_state.response = ""
32
 
33
- with st.chat_message("ai", avatar=st.session_state.assistant_avatar):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  penalty_kwargs = {
35
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
36
  }
37
-
38
- stream = client.chat_stream(
39
- message=prompt,
40
- model=model,
41
- preamble=preamble,
42
- chat_history=chat_history,
43
- temperature=temperature,
44
- k=k,
45
- p=p,
46
  **penalty_kwargs
47
- )
 
 
 
 
 
 
 
48
 
49
  placeholder = st.empty()
50
 
@@ -63,10 +120,10 @@ def get_ai_response(prompt, chat_history):
63
  st.button("Stop generating")
64
 
65
  shown_message = ""
66
-
67
- for event in stream:
68
- if event.event_type == "text-generation":
69
- content = event.text
70
  st.session_state.response += content
71
  shown_message += content.replace("\n", " \n")\
72
  .replace("<", "\\<")\
@@ -90,10 +147,9 @@ inline_pattern = r"`([^`\n]+?)`"
90
 
91
  def display_messages():
92
  for i, message in enumerate(st.session_state.messages):
93
- name = "user" if message["role"] == "USER" else "ai"
94
- avatar = st.session_state.user_avatar if message["role"] == "USER" else st.session_state.assistant_avatar
95
- with st.chat_message(name, avatar=avatar):
96
- shown_message = message["text"].replace("\n", " \n")\
97
  .replace("<", "\\<")\
98
  .replace(">", "\\>")
99
  if "```" in shown_message:
@@ -113,7 +169,7 @@ def display_messages():
113
  del st.session_state.messages[i]
114
  st.rerun()
115
  with col3:
116
- text_to_copy = message["text"]
117
  # Encode the string to escape
118
  text_to_copy_escaped = urllib.parse.quote(text_to_copy)
119
 
@@ -131,7 +187,7 @@ def display_messages():
131
  """
132
  html(copy_button_html, height=50)
133
 
134
- if i == len(st.session_state.messages) - 1 and message["role"] == "CHATBOT":
135
  with col4:
136
  if st.button("Retry", key=f"retry_{i}_{len(st.session_state.messages)}"):
137
  if len(st.session_state.messages) >= 2:
@@ -141,11 +197,11 @@ def display_messages():
141
 
142
  if "edit_index" in st.session_state and st.session_state.edit_index == i:
143
  with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
144
- new_content = st.text_area("Edit message", height=200, value=st.session_state.messages[i]["text"])
145
  col1, col2 = st.columns([1, 1])
146
  with col1:
147
  if st.form_submit_button("Save"):
148
- st.session_state.messages[i]["text"] = new_content
149
  del st.session_state.edit_index
150
  st.rerun()
151
  with col2:
@@ -163,12 +219,12 @@ with st.sidebar:
163
  # Copy Conversation History button
164
  log_text = ""
165
  for message in st.session_state.messages:
166
- if message["role"] == "USER":
167
  log_text += "<USER>\n"
168
- log_text += message["text"] + "\n\n"
169
  else:
170
  log_text += "<ASSISTANT>\n"
171
- log_text += message["text"] + "\n\n"
172
  log_text = log_text.rstrip("\n")
173
 
174
  # Encode the string to escape
@@ -178,12 +234,21 @@ with st.sidebar:
178
  <button id="copy-log-btn" style='font-size: 1em; padding: 0.5em;' onclick='copyLog()'>Copy Conversation History</button>
179
 
180
  <script>
 
181
  function copyLog() {{
182
- navigator.clipboard.writeText(decodeURIComponent("{log_text_escaped}"));
183
- let copyBtn = document.getElementById("copy-log-btn");
184
  copyBtn.innerHTML = "Copied!";
185
  setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
186
  }}
 
 
 
 
 
 
 
 
187
  </script>
188
  """
189
  html(copy_log_button_html, height=50)
@@ -208,7 +273,56 @@ with st.sidebar:
208
  p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
209
  penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
210
  penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
211
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  st.header("Restore History")
213
  history_input = st.text_area("Paste conversation history:", height=200)
214
  if st.button("Restore History"):
@@ -219,13 +333,13 @@ with st.sidebar:
219
  for message in messages:
220
  if message.strip() in ["<USER>", "<ASSISTANT>"]:
221
  if role and text:
222
- st.session_state.messages.append({"role": role, "text": text.strip()})
223
  text = ""
224
- role = "USER" if message.strip() == "<USER>" else "CHATBOT"
225
  else:
226
  text += message
227
  if role and text:
228
- st.session_state.messages.append({"role": role, "text": text.strip()})
229
  st.rerun()
230
 
231
  st.header("Clear History")
@@ -238,13 +352,14 @@ with st.sidebar:
238
  if st.button("Update API Key"):
239
  if new_api_key and new_api_key.isascii():
240
  st.session_state.api_key = new_api_key
241
- client = cohere.Client(api_key=new_api_key)
242
  st.success("API Key updated successfully!")
243
  else:
244
  st.warning("Please enter a valid API Key.")
245
 
246
  with appearance_tab:
247
  st.header("Font Selection")
 
248
  font_options = {
249
  "Zen Maru Gothic": "Zen Maru Gothic",
250
  "Noto Sans JP": "Noto Sans JP",
@@ -267,7 +382,7 @@ with st.sidebar:
267
 
268
  # After Stop generating
269
  if st.session_state.get("is_streaming"):
270
- st.session_state.messages.append({"role": "CHATBOT", "text": st.session_state.response})
271
  st.session_state.is_streaming = False
272
  if "retry_flag" in st.session_state and st.session_state.retry_flag:
273
  st.session_state.retry_flag = False
@@ -311,16 +426,16 @@ display_messages()
311
  # After Retry
312
  if st.session_state.get("retry_flag"):
313
  if len(st.session_state.messages) > 0:
314
- prompt = st.session_state.messages[-1]["text"]
315
- messages = st.session_state.messages[:-1].copy()
316
- response = get_ai_response(prompt, messages)
317
- st.session_state.messages.append({"role": "CHATBOT", "text": response})
318
  st.session_state.retry_flag = False
319
  st.rerun()
320
  else:
321
  st.session_state.retry_flag = False
322
 
323
  if prompt := st.chat_input("Enter your message here..."):
 
324
  chat_history = st.session_state.messages.copy()
325
 
326
  shown_message = prompt.replace("\n", " \n")\
@@ -330,9 +445,7 @@ if prompt := st.chat_input("Enter your message here..."):
330
  with st.chat_message("user", avatar=st.session_state.user_avatar):
331
  st.write(shown_message)
332
 
333
- st.session_state.messages.append({"role": "USER", "text": prompt})
334
-
335
- response = get_ai_response(prompt, chat_history)
336
 
337
- st.session_state.messages.append({"role": "CHATBOT", "text": response})
338
  st.rerun()
 
4
  from streamlit_extras.stylable_container import stylable_container
5
  import re
6
  import urllib.parse
7
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+ import numpy as np
9
+ import pypdfium2 as pdfium
10
 
11
  st.title("Cohere Chat UI")
12
 
 
15
  if api_key:
16
  if api_key.isascii():
17
  st.session_state.api_key = api_key
18
+ client = cohere.ClientV2(api_key=api_key)
19
  st.rerun()
20
  else:
21
  st.warning("Please enter your API key correctly.")
 
24
  st.warning("Please enter your API key to use the app. You can obtain your API key from here: https://dashboard.cohere.com/api-keys")
25
  st.stop()
26
  else:
27
+ client = cohere.ClientV2(api_key=st.session_state.api_key)
28
 
29
  if "messages" not in st.session_state:
30
  st.session_state.messages = []
31
 
32
+ if "rag_file_key" not in st.session_state:
33
+ st.session_state.rag_file_key = None
34
+
35
+ if "rag_embedded" not in st.session_state:
36
+ st.session_state.rag_embedded = False
37
+
38
+ text_splitter = RecursiveCharacterTextSplitter(
39
+ chunk_size=512,
40
+ chunk_overlap=50,
41
+ length_function=len,
42
+ is_separator_regex=False,
43
+ )
44
+
45
+ def batch_embed(texts, batch_size=96):
46
+ all_embeddings = []
47
+ for i in range(0, len(texts), batch_size):
48
+ batch = texts[i:i+batch_size]
49
+ response = client.embed(
50
+ texts=batch,
51
+ model=embed_model,
52
+ input_type="search_document",
53
+ embedding_types=['float']
54
+ )
55
+ all_embeddings.extend(response.embeddings.float)
56
+ return all_embeddings
57
+
58
+ def cosine_similarity(a, b):
59
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
60
+
61
+ def get_ai_response(chat_history):
62
  st.session_state.is_streaming = True
63
  st.session_state.response = ""
64
 
65
+ with st.chat_message("assistant", avatar=st.session_state.assistant_avatar):
66
+ # RAG
67
+ if st.session_state.get("rag_chunks") and st.session_state.get("rag_embeddings"):
68
+ chunks = st.session_state.rag_chunks
69
+ embeddings = st.session_state.rag_embeddings
70
+
71
+ vector_database = {i: np.array(embedding) for i, embedding in enumerate(embeddings)}
72
+
73
+ query = chat_history[-1]["content"]
74
+ query_embedding = client.embed(texts=[query], model=embed_model, input_type="search_query", embedding_types=['float']).embeddings.float[0]
75
+
76
+ similarities = [cosine_similarity(query_embedding, chunk) for chunk in embeddings]
77
+ top_indices = np.argsort(similarities)[::-1][:10]
78
+ top_chunks_after_retrieval = [chunks[i] for i in top_indices]
79
+
80
+ rerank_response = client.rerank(query=query, documents=top_chunks_after_retrieval, top_n=3, model=rerank_model)
81
+ top_chunks_after_rerank = [top_chunks_after_retrieval[result.index] for result in rerank_response.results]
82
+ documents = [{"data": {"title": f"chunk {i}", "snippet": chunk}} for i, chunk in enumerate(top_chunks_after_rerank)]
83
+
84
  penalty_kwargs = {
85
  "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
86
  }
87
+
88
+ chat_history.insert(0, {"role": "system", "content": preamble})
89
+
90
+ stream_kwargs = {
91
+ "messages": chat_history,
92
+ "model": model,
93
+ "temperature": temperature,
94
+ "k": k,
95
+ "p": p,
96
  **penalty_kwargs
97
+ }
98
+
99
+ if st.session_state.get("rag_text"):
100
+ stream_kwargs["documents"] = documents
101
+ elif model in ["command-r-08-2024", "command-r-plus-08-2024"]:
102
+ stream_kwargs["safety_mode"] = "OFF"
103
+
104
+ stream = client.chat_stream(**stream_kwargs)
105
 
106
  placeholder = st.empty()
107
 
 
120
  st.button("Stop generating")
121
 
122
  shown_message = ""
123
+
124
+ for chunk in stream:
125
+ if chunk.type == "content-delta":
126
+ content = chunk.delta.message.content.text
127
  st.session_state.response += content
128
  shown_message += content.replace("\n", " \n")\
129
  .replace("<", "\\<")\
 
147
 
148
  def display_messages():
149
  for i, message in enumerate(st.session_state.messages):
150
+ avatar = st.session_state.user_avatar if message["role"] == "user" else st.session_state.assistant_avatar
151
+ with st.chat_message(message["role"], avatar=avatar):
152
+ shown_message = message["content"].replace("\n", " \n")\
 
153
  .replace("<", "\\<")\
154
  .replace(">", "\\>")
155
  if "```" in shown_message:
 
169
  del st.session_state.messages[i]
170
  st.rerun()
171
  with col3:
172
+ text_to_copy = message["content"]
173
  # Encode the string to escape
174
  text_to_copy_escaped = urllib.parse.quote(text_to_copy)
175
 
 
187
  """
188
  html(copy_button_html, height=50)
189
 
190
+ if i == len(st.session_state.messages) - 1 and message["role"] == "assistant":
191
  with col4:
192
  if st.button("Retry", key=f"retry_{i}_{len(st.session_state.messages)}"):
193
  if len(st.session_state.messages) >= 2:
 
197
 
198
  if "edit_index" in st.session_state and st.session_state.edit_index == i:
199
  with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
200
+ new_content = st.text_area("Edit message", height=200, value=st.session_state.messages[i]["content"])
201
  col1, col2 = st.columns([1, 1])
202
  with col1:
203
  if st.form_submit_button("Save"):
204
+ st.session_state.messages[i]["content"] = new_content
205
  del st.session_state.edit_index
206
  st.rerun()
207
  with col2:
 
219
  # Copy Conversation History button
220
  log_text = ""
221
  for message in st.session_state.messages:
222
+ if message["role"] == "user":
223
  log_text += "<USER>\n"
224
+ log_text += message["content"] + "\n\n"
225
  else:
226
  log_text += "<ASSISTANT>\n"
227
+ log_text += message["content"] + "\n\n"
228
  log_text = log_text.rstrip("\n")
229
 
230
  # Encode the string to escape
 
234
  <button id="copy-log-btn" style='font-size: 1em; padding: 0.5em;' onclick='copyLog()'>Copy Conversation History</button>
235
 
236
  <script>
237
+ const log_text_escaped = "{log_text_escaped}";
238
  function copyLog() {{
239
+ navigator.clipboard.writeText(decodeURIComponent(log_text_escaped));
240
+ const copyBtn = document.getElementById("copy-log-btn");
241
  copyBtn.innerHTML = "Copied!";
242
  setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
243
  }}
244
+ window.parent.document.addEventListener('keydown', (e) => {{
245
+ if ( e.code == "Pause" ){{
246
+ window.parent.navigator.clipboard.writeText(decodeURIComponent(log_text_escaped));
247
+ const copyBtn = document.getElementById("copy-log-btn");
248
+ copyBtn.innerHTML = "Copied!";
249
+ setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
250
+ }}
251
+ }} , false);
252
  </script>
253
  """
254
  html(copy_log_button_html, height=50)
 
273
  p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
274
  penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
275
  penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
276
+
277
+ st.header("RAG")
278
+ st.markdown("Select the model and encoding before uploading the file.")
279
+ rag_model = st.selectbox("RAG Model", options=["Multilingual", "English"], index=0)
280
+ file_encoding = st.selectbox("Encoding", options=["utf_8", "shift_jis"], index=0)
281
+ st.session_state.rag_file = st.file_uploader("Choose a txt or pdf file", type=["txt", "pdf"], key="rag_file_uploader")
282
+
283
+ if rag_model == "Multilingual":
284
+ embed_model = "embed-multilingual-v3.0"
285
+ rerank_model = "rerank-multilingual-v3.0"
286
+ else:
287
+ embed_model = "embed-english-v3.0"
288
+ rerank_model = "rerank-english-v3.0"
289
+
290
+ if st.session_state.rag_file is not None:
291
+ if st.session_state.rag_file_key != st.session_state.rag_file:
292
+ st.session_state.rag_file_key = st.session_state.rag_file
293
+ st.session_state.rag_embedded = False
294
+ if "rag_text" in st.session_state:
295
+ del st.session_state.rag_text
296
+ if "rag_chunks" in st.session_state:
297
+ del st.session_state.rag_chunks
298
+ if "rag_embeddings" in st.session_state:
299
+ del st.session_state.rag_embeddings
300
+
301
+ if not st.session_state.rag_embedded:
302
+ if st.session_state.rag_file.type == "application/pdf":
303
+ pdf = pdfium.PdfDocument(st.session_state.rag_file)
304
+ st.session_state.rag_text = ""
305
+ for page in pdf:
306
+ textpage = page.get_textpage()
307
+ st.session_state.rag_text += textpage.get_text_range()
308
+ else:
309
+ st.session_state.rag_text = st.session_state.rag_file.read().decode(file_encoding)
310
+ chunks_ = text_splitter.create_documents([st.session_state.rag_text])
311
+ chunks = [c.page_content for c in chunks_]
312
+ embeddings = batch_embed(chunks)
313
+ st.session_state.rag_chunks = chunks
314
+ st.session_state.rag_embeddings = embeddings
315
+ st.session_state.rag_embedded = True
316
+ else:
317
+ st.session_state.rag_file_key = None
318
+ st.session_state.rag_embedded = False
319
+ if "rag_text" in st.session_state:
320
+ del st.session_state.rag_text
321
+ if "rag_chunks" in st.session_state:
322
+ del st.session_state.rag_chunks
323
+ if "rag_embeddings" in st.session_state:
324
+ del st.session_state.rag_embeddings
325
+
326
  st.header("Restore History")
327
  history_input = st.text_area("Paste conversation history:", height=200)
328
  if st.button("Restore History"):
 
333
  for message in messages:
334
  if message.strip() in ["<USER>", "<ASSISTANT>"]:
335
  if role and text:
336
+ st.session_state.messages.append({"role": role, "content": text.strip()})
337
  text = ""
338
+ role = "user" if message.strip() == "<USER>" else "assistant"
339
  else:
340
  text += message
341
  if role and text:
342
+ st.session_state.messages.append({"role": role, "content": text.strip()})
343
  st.rerun()
344
 
345
  st.header("Clear History")
 
352
  if st.button("Update API Key"):
353
  if new_api_key and new_api_key.isascii():
354
  st.session_state.api_key = new_api_key
355
+ client = cohere.ClientV2(api_key=new_api_key)
356
  st.success("API Key updated successfully!")
357
  else:
358
  st.warning("Please enter a valid API Key.")
359
 
360
  with appearance_tab:
361
  st.header("Font Selection")
362
+
363
  font_options = {
364
  "Zen Maru Gothic": "Zen Maru Gothic",
365
  "Noto Sans JP": "Noto Sans JP",
 
382
 
383
  # After Stop generating
384
  if st.session_state.get("is_streaming"):
385
+ st.session_state.messages.append({"role": "assistant", "content": st.session_state.response})
386
  st.session_state.is_streaming = False
387
  if "retry_flag" in st.session_state and st.session_state.retry_flag:
388
  st.session_state.retry_flag = False
 
426
  # After Retry
427
  if st.session_state.get("retry_flag"):
428
  if len(st.session_state.messages) > 0:
429
+ messages = st.session_state.messages.copy()
430
+ response = get_ai_response(messages)
431
+ st.session_state.messages.append({"role": "assistant", "content": response})
 
432
  st.session_state.retry_flag = False
433
  st.rerun()
434
  else:
435
  st.session_state.retry_flag = False
436
 
437
  if prompt := st.chat_input("Enter your message here..."):
438
+ st.session_state.messages.append({"role": "user", "content": prompt})
439
  chat_history = st.session_state.messages.copy()
440
 
441
  shown_message = prompt.replace("\n", " \n")\
 
445
  with st.chat_message("user", avatar=st.session_state.user_avatar):
446
  st.write(shown_message)
447
 
448
+ response = get_ai_response(chat_history)
 
 
449
 
450
+ st.session_state.messages.append({"role": "assistant", "content": response})
451
  st.rerun()
requirements.txt CHANGED
@@ -1,2 +1,5 @@
1
  cohere
2
- streamlit-extras
 
 
 
 
1
  cohere
2
+ streamlit-extras
3
+ numpy
4
+ langchain-text-splitters
5
+ pypdfium2