terapyon committed
Commit ed3c145
1 Parent(s): 60093a0

dev/streamlit-ui (#12)

- modify to sdk streamlit (06d784cd04f9d79f73d733703dc18fcff332a3c0)

Files changed (2)
  1. README.md +2 -2
  2. app.py +60 -38
README.md CHANGED
@@ -3,8 +3,8 @@ title: NVDA 日本語版ガイドブックQA
 emoji: 👀
 colorFrom: green
 colorTo: yellow
-sdk: gradio
-sdk_version: 3.38.0
+sdk: streamlit
+sdk_version: 1.25.0
 app_file: app.py
 pinned: false
 license: cc0-1.0
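
For reference, the Space front matter after this change, reconstructed from the hunk above (`sdk` and `sdk_version` are the keys Hugging Face Spaces reads to pick the runtime):

    title: NVDA 日本語版ガイドブックQA
    emoji: 👀
    colorFrom: green
    colorTo: yellow
    sdk: streamlit
    sdk_version: 1.25.0
    app_file: app.py
    pinned: false
    license: cc0-1.0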
app.py CHANGED
@@ -1,9 +1,13 @@
 from time import time
-import gradio as gr
+from typing import Iterable
+
+# import gradio as gr
+import streamlit as st
 from langchain.chains import RetrievalQA
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.prompts import PromptTemplate
+
+# from langchain.prompts import PromptTemplate
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from langchain.llms import HuggingFacePipeline
@@ -25,7 +29,7 @@ E5_EMBEDDINGS = HuggingFaceEmbeddings(
     encode_kwargs=E5_ENCODE_KWARGS,
 )
 
-if torch.cuda.is_available():
+if False and torch.cuda.is_available():  # TODO: for local debug
     RINNA_MODEL_NAME = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
     RINNA_TOKENIZER = AutoTokenizer.from_pretrained(RINNA_MODEL_NAME, use_fast=False)
     RINNA_MODEL = AutoModelForCausalLM.from_pretrained(
@@ -86,17 +90,6 @@ def _get_llm_model(
     return llm
 
 
-# prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
-
-# {context}
-
-# Question: {question}
-# Answer in Japanese:"""
-# PROMPT = PromptTemplate(
-#     template=prompt_template, input_variables=["context", "question"]
-# )
-
-
 def get_retrieval_qa(
     collection_name: str | None,
     model_name: str | None,
@@ -122,7 +115,6 @@ def get_retrieval_qa(
     llm = _get_llm_model(model_name, temperature)
 
     # chain_type_kwargs = {"prompt": PROMPT}
-
     result = RetrievalQA.from_chain_type(
         llm=llm,
         chain_type="stuff",
@@ -146,11 +138,8 @@ def get_related_url(metadata):
         yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
 
 
-def main(
-    query: str, collection_name: str, model_name: str, option: str, temperature: float
-):
+def run_qa(query: str, qa: RetrievalQA) -> tuple[str, str]:
     now = time()
-    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
     try:
         result = qa(query)
     except InvalidRequestError as e:
@@ -163,29 +152,62 @@ def main(
     return result["result"], html
 
 
+def main(
+    query: str,
+    collection_name: str | None,
+    model_name: str | None,
+    option: str | None,
+    temperature: float,
+    e5_option: list[str],
+) -> Iterable[tuple[str, tuple[str, str]]]:
+    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
+    if collection_name == "E5":
+        for option in e5_option:
+            if option == "No":
+                yield "E5 No", run_qa(query, qa)
+            elif option == "Query":
+                yield "E5 Query", run_qa("query: " + query, qa)
+            elif option == "Passage":
+                yield "E5 Passage", run_qa("passage: " + query, qa)
+            else:
+                raise ValueError("Unknown option")
+    else:
+        yield "OpenAI", run_qa(query, qa)
+
+
 AVAILABLE_LLMS = ["GPT-3.5", "GPT-4"]
 
 if RINNA_MODEL is not None:
     AVAILABLE_LLMS.append("rinna")
 
-nvdajp_book_qa = gr.Interface(
-    fn=main,
-    inputs=[
-        gr.Textbox(label="query"),
-        gr.Radio(["E5", "OpenAI"], value="E5", label="Embedding"),
-        gr.Radio(
-            AVAILABLE_LLMS, value="GPT-3.5", label="Model", info="GPU環境だとrinnaが選択可能"
-        ),
-        gr.Radio(
-            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
-            value="All",
-            label="絞り込み",
-            info="ドキュメント制限する?",
-        ),
-        gr.Slider(0, 2),
-    ],
-    outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
-)
+with st.form("my_form"):
+    query = st.text_input(label="query")
+    collection_name = st.radio(options=["E5", "OpenAI"], label="Embedding")
 
-
-nvdajp_book_qa.launch()
+    # if collection_name == "E5":  # TODO: 選択肢で選べるようにする
+    e5_option = st.multiselect("E5 option", ["No", "Query", "Passage"], default="No")
 
+    model_name = st.radio(
+        options=AVAILABLE_LLMS,
+        label="Model",
+        help="GPU環境だとrinnaが選択可能",
+    )
+    option = st.radio(
+        options=["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
+        label="絞り込み",
+        help="ドキュメント制限する?",
+    )
+    temperature = st.slider(label="temperature", min_value=0, max_value=2)
+
+    submitted = st.form_submit_button("Submit")
+    if submitted:
+        with st.spinner("Searching..."):
+            results = main(
+                query, collection_name, model_name, option, temperature, e5_option
+            )
+            for type_, (answer, html) in results:
+                with st.container():
+                    st.header(type_)
+                    st.write(answer)
                    st.markdown(html, unsafe_allow_html=True)
+                    st.divider()
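
The new `main` is a generator that fans one query out over the selected E5 modes; the "query: " / "passage: " prefixes follow the input convention of the E5 embedding models. A minimal sketch of driving it outside the Streamlit form, assuming this file is importable as a module named `app` (the module name and the sample query are illustrative, and the GPT paths still need the Space's vector store and OpenAI key):

    # Sketch only: importing app also executes its module-level st.* calls,
    # which run in bare mode (no UI) outside `streamlit run`.
    from app import main

    results = main(
        query="NVDAのインストール方法は?",  # illustrative sample query
        collection_name="E5",               # use the E5 embedding collection
        model_name="GPT-3.5",
        option="All",                       # no document filter
        temperature=0.0,
        e5_option=["Query", "Passage"],     # compare both E5 prefix modes
    )
    for label, (answer, html) in results:   # each run_qa call returns (answer, html)
        print(label)                        # "E5 Query", then "E5 Passage"
        print(answer)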