shinichi-a committed on
Commit
fa9309e
1 Parent(s): b75c3e3

Add "Search Only" to OpenAI model options and make OpenAI API key input optional


This pull request makes the following changes:

1. Added "Search Only" to the options for the OpenAI model. This allows users to perform only a search without using the OpenAI model.

2. Made the OpenAI API key input optional. A search can now be executed even when no API key is entered.

Together, these changes let users run a search on its own, with or without an OpenAI API key, which makes the app easier to try out.
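Concretely, an answer is generated only when an API key is present and a model other than "Search Only" is selected; otherwise the app stops after showing the search results. Below is a minimal sketch of that gate, using the same session-state values as app.py (the helper function name is illustrative, not part of the code):

```python
from __future__ import annotations


def should_generate_answer(openai_api_key: str | None, openai_model_name: str) -> bool:
    """Gate for answer generation: requires an OpenAI API key and a model
    selection other than the new "Search Only" entry."""
    return bool(openai_api_key) and openai_model_name != "Search Only"


# Expected behaviour after this change:
assert should_generate_answer("sk-demo-key", "gpt-3.5-turbo-1106")
assert not should_generate_answer("", "gpt-4-1106-preview")        # no key -> search only
assert not should_generate_answer("sk-demo-key", "Search Only")    # explicit search-only mode
```

For a quick local check, the app can be started with streamlit run app.py --server.address 0.0.0.0 (the command from the module docstring this commit removes); leaving the key field empty, or selecting "Search Only", should show search results without an Answer section.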

Please review and let me know your thoughts.

Files changed (1)
  1. app.py +22 -36
app.py CHANGED
@@ -1,23 +1,15 @@
-"""
-streamlit run app.py --server.address 0.0.0.0
-"""
-
 from __future__ import annotations
 
-import streamlit as st
 import os
-
-import faiss
-from sentence_transformers import SentenceTransformer
+import pandas as pd
 import torch
-from openai import OpenAI
+import faiss
 import streamlit as st
-import pandas as pd
-import os
 from time import time
+from openai import OpenAI
+from sentence_transformers import SentenceTransformer
+from datasets import load_dataset
 from datasets.download import DownloadManager
-from datasets import load_dataset  # type: ignore
-
 
 WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
 WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
@@ -36,6 +28,7 @@ EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())
 OPENAI_MODEL_NAMES = [
     "gpt-3.5-turbo-1106",
     "gpt-4-1106-preview",
+    "Search Only",
 ]
 
 E5_QUERY_TYPES = [
@@ -60,7 +53,6 @@ Responses must be given in Japanese.
 {question}
 """.strip()
 
-
 if os.getenv("SPACE_ID"):
     USE_HF_SPACE = True
     os.environ["HF_HOME"] = "/data/.huggingface"
@@ -68,9 +60,7 @@ if os.getenv("SPACE_ID"):
 else:
     USE_HF_SPACE = False
 
-# for tokenizer
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 
 
@@ -81,6 +71,7 @@ def get_model(name: str, max_seq_length=512):
         device = "cuda"
     elif torch.backends.mps.is_available():
         device = "mps"
+
     model = SentenceTransformer(name, device=device)
     model.max_seq_length = max_seq_length
     return model
@@ -93,9 +84,7 @@ def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
 
 
 @st.cache_resource
-def get_faiss_index(
-    index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME
-):
+def get_faiss_index(index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME):
     target_path = f"faiss_indexes/{name}/{index_name}"
     dm = DownloadManager()
     index_local_path = dm.download(
@@ -110,9 +99,7 @@ def text_to_emb(model, text: str, prefix: str):
     return model.encode([prefix + text], normalize_embeddings=True)
 
 
-def search(
-    faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int
-):
+def search(faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int):
     start_time = time()
     emb = text_to_emb(emb_model, question, search_text_prefix)
     emb_exec_time = time() - start_time
@@ -121,7 +108,7 @@ def search(
     scores = scores[0]
     indexes = indexes[0]
     results = []
-    for idx, score in zip(indexes, scores):  # type: ignore
+    for idx, score in zip(indexes, scores):
         idx = int(idx)
         passage = ds[idx]
         results.append((score, passage))
@@ -133,7 +120,6 @@ def to_contexts(passages):
     for passage in passages:
         title = passage["title"]
         text = passage["text"]
-        # section = passage["section"]
         contexts += f"- {title}: {text}\n"
     return contexts
 
@@ -211,15 +197,13 @@ def app():
         key="question",
         value="楽曲『約束はいらない』でデビューした、声優は誰?",
     )
-    if not OPENAI_API_KEY:
-        st.text_input(
-            "OpenAI API Key",
-            key="openai_api_key",
-            type="password",
-            placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
-        )
-    else:
-        st.session_state.openai_api_key = OPENAI_API_KEY
+    st.text_input(
+        "OpenAI API Key",
+        key="openai_api_key",
+        type="password",
+        value=OPENAI_API_KEY if OPENAI_API_KEY else "",
+        placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
+    )
 
     with st.expander("オプション"):
         option_cols_main = st.columns(2)
@@ -229,6 +213,8 @@ def app():
             st.selectbox(
                 "OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
             )
+        if "emb_model_name" not in st.session_state:
+            st.session_state.emb_model_name = EMB_MODEL_NAMES[0]  # replace with the actual default value you want to use
         emb_model_name = st.session_state.emb_model_name
         option_cols_sub = st.columns(2)
         with option_cols_sub[0]:
@@ -300,10 +286,10 @@ def app():
     st.dataframe(df, hide_index=True)
 
     openai_api_key = st.session_state.openai_api_key
-    if openai_api_key:
+    openai_model_name = st.session_state.openai_model_name
+    if openai_api_key and openai_model_name != "Search Only":
         openai_api_key = openai_api_key.strip()
         answer_header.subheader("Answer: ")
-        openai_model_name = st.session_state.openai_model_name
         temperature = st.session_state.temperature
         qa_prompt = st.session_state.qa_prompt
         max_tokens = st.session_state.max_tokens
@@ -320,4 +306,4 @@ def app():
 
 
 if __name__ == "__main__":
-    app()
+    app()