shinichi-a committed
Commit fa9309e · Parent: b75c3e3
Add "Search Only" to OpenAI model options and make OpenAI API key input optional
In this pull request, the following changes have been made:
1. Added "Search Only" to the options for the OpenAI model. This allows users to perform only a search without using the OpenAI model.
2. Made the input of the OpenAI API key optional. This allows the search to be executed even if the OpenAI API key is not entered.
These changes allow users to perform only a search without using the OpenAI model and to execute a search even if the OpenAI API key is not entered, improving user convenience.
Please review and let me know your thoughts.
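For reviewers, here is a minimal sketch of the new gating behaviour. The helper name is hypothetical and not part of the diff; the condition mirrors the one added in app.py, which reads `st.session_state.openai_api_key` and `st.session_state.openai_model_name`:

```python
def should_generate_answer(openai_api_key: str, openai_model_name: str) -> bool:
    """Answer generation runs only with a non-empty API key and a real model selected."""
    return bool(openai_api_key) and openai_model_name != "Search Only"


# No key entered, or "Search Only" selected -> search only, no OpenAI call.
assert should_generate_answer("", "gpt-3.5-turbo-1106") is False
assert should_generate_answer("sk-xxxx", "Search Only") is False
# Key present and a real model selected -> search, then answer generation.
assert should_generate_answer("sk-xxxx", "gpt-4-1106-preview") is True
```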
app.py
CHANGED
@@ -1,23 +1,15 @@
-"""
-streamlit run app.py --server.address 0.0.0.0
-"""
-
 from __future__ import annotations
 
-import streamlit as st
 import os
-
-import faiss
-from sentence_transformers import SentenceTransformer
+import pandas as pd
 import torch
-
+import faiss
 import streamlit as st
-import pandas as pd
-import os
 from time import time
+from openai import OpenAI
+from sentence_transformers import SentenceTransformer
+from datasets import load_dataset
 from datasets.download import DownloadManager
-from datasets import load_dataset  # type: ignore
-
 
 WIKIPEDIA_JA_DS = "singletongue/wikipedia-utils"
 WIKIPEDIA_JS_DS_NAME = "passages-c400-jawiki-20230403"
@@ -36,6 +28,7 @@ EMB_MODEL_NAMES = list(EMB_MODEL_PQ.keys())
 OPENAI_MODEL_NAMES = [
     "gpt-3.5-turbo-1106",
     "gpt-4-1106-preview",
+    "Search Only",
 ]
 
 E5_QUERY_TYPES = [
@@ -60,7 +53,6 @@ Responses must be given in Japanese.
 {question}
 """.strip()
 
-
 if os.getenv("SPACE_ID"):
     USE_HF_SPACE = True
     os.environ["HF_HOME"] = "/data/.huggingface"
@@ -68,9 +60,7 @@ if os.getenv("SPACE_ID"):
 else:
     USE_HF_SPACE = False
 
-# for tokenizer
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
 
 
@@ -81,6 +71,7 @@ def get_model(name: str, max_seq_length=512):
         device = "cuda"
     elif torch.backends.mps.is_available():
         device = "mps"
+
     model = SentenceTransformer(name, device=device)
     model.max_seq_length = max_seq_length
     return model
@@ -93,9 +84,7 @@ def get_wikija_ds(name: str = WIKIPEDIA_JS_DS_NAME):
 
 
 @st.cache_resource
-def get_faiss_index(
-    index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME
-):
+def get_faiss_index(index_name: str, ja_emb_ds: str = WIKIPEDIA_JA_EMB_DS, name=WIKIPEDIA_JS_DS_NAME):
     target_path = f"faiss_indexes/{name}/{index_name}"
     dm = DownloadManager()
     index_local_path = dm.download(
@@ -110,9 +99,7 @@ def text_to_emb(model, text: str, prefix: str):
     return model.encode([prefix + text], normalize_embeddings=True)
 
 
-def search(
-    faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int
-):
+def search(faiss_index, emb_model, ds, question: str, search_text_prefix: str, top_k: int):
     start_time = time()
     emb = text_to_emb(emb_model, question, search_text_prefix)
     emb_exec_time = time() - start_time
@@ -121,7 +108,7 @@ def search(
     scores = scores[0]
     indexes = indexes[0]
     results = []
-    for idx, score in zip(indexes, scores):
+    for idx, score in zip(indexes, scores):
         idx = int(idx)
         passage = ds[idx]
         results.append((score, passage))
@@ -133,7 +120,6 @@ def to_contexts(passages):
     for passage in passages:
         title = passage["title"]
         text = passage["text"]
-        # section = passage["section"]
         contexts += f"- {title}: {text}\n"
     return contexts
 
@@ -211,15 +197,13 @@ def app():
         key="question",
         value="楽曲『約束はいらない』でデビューした、声優は誰?",
     )
-
-
-
-
-
-
-
-    else:
-        st.session_state.openai_api_key = OPENAI_API_KEY
+    st.text_input(
+        "OpenAI API Key",
+        key="openai_api_key",
+        type="password",
+        value=OPENAI_API_KEY if OPENAI_API_KEY else "",
+        placeholder="※ OpenAI API Key 未入力時は回答を生成せずに、検索のみ実行します",
+    )
 
     with st.expander("オプション"):
         option_cols_main = st.columns(2)
@@ -229,6 +213,8 @@ def app():
            st.selectbox(
                "OpenAI Model", OPENAI_MODEL_NAMES, index=0, key="openai_model_name"
            )
+        if "emb_model_name" not in st.session_state:
+            st.session_state.emb_model_name = EMB_MODEL_NAMES[0]  # replace with the actual default value you want to use
         emb_model_name = st.session_state.emb_model_name
         option_cols_sub = st.columns(2)
         with option_cols_sub[0]:
@@ -300,10 +286,10 @@ def app():
     st.dataframe(df, hide_index=True)
 
     openai_api_key = st.session_state.openai_api_key
-
+    openai_model_name = st.session_state.openai_model_name
+    if openai_api_key and openai_model_name != "Search Only":
     openai_api_key = openai_api_key.strip()
     answer_header.subheader("Answer: ")
-    openai_model_name = st.session_state.openai_model_name
     temperature = st.session_state.temperature
     qa_prompt = st.session_state.qa_prompt
     max_tokens = st.session_state.max_tokens
@@ -320,4 +306,4 @@ def app():
 
 
 if __name__ == "__main__":
-    app()
+    app()