# nvdajp-book-qa / app.py
# NOTE(review): the following Hugging Face Spaces page chrome was pasted into
# the source and would be a syntax error — kept here as a comment:
#   terapyon's picture / "can select model for GPT-4" / commit 33db183
#   raw history blame / No virus / 2.77 kB
import gradio as gr
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from config import DB_CONFIG
# NOTE(review): appears unused in this file — presumably the name the vector
# collection was persisted under; confirm against the indexing script before
# removing.
PERSIST_DIR_NAME = "nvdajp-book"
# Previously hard-coded model ids, now chosen at runtime in get_retrieval_qa:
# MODEL_NAME = "text-davinci-003"
# MODEL_NAME = "gpt-3.5-turbo"
# MODEL_NAME = "gpt-4"
def get_retrieval_qa(model_name: str | None, temperature: float, option: str | None) -> RetrievalQA:
    """Build a RetrievalQA chain backed by the project's Qdrant vector store.

    Args:
        model_name: UI label, "GPT-3.5" or "GPT-4". Anything else
            (including ``None``) falls back to gpt-3.5-turbo.
        temperature: Sampling temperature forwarded to the chat model
            (the UI slider supplies a float in [0, 2]).
        option: Document category to restrict retrieval to; ``None`` or
            ``"All"`` searches every category.

    Returns:
        A "stuff"-type RetrievalQA chain configured to also return the
        source documents it retrieved.
    """
    embeddings = OpenAIEmbeddings()
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
    # Map the UI label to an OpenAI model id. Only "GPT-4" selects gpt-4;
    # every other value (None, "GPT-3.5", unknown) means gpt-3.5-turbo,
    # so the former four-branch if/elif collapses to one expression.
    model = "gpt-4" if model_name == "GPT-4" else "gpt-3.5-turbo"
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        # Restrict search to documents whose metadata "category" matches.
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(
            model=model,
            temperature=temperature,
        ),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
def get_related_url(metadata):
    """Yield one HTML paragraph per distinct URL found in *metadata*.

    Each entry links the URL and shows its document category; repeated
    URLs are emitted only the first time they are seen.
    """
    seen = set()
    for entry in metadata:
        link = entry["url"]
        if link not in seen:
            seen.add(link)
            cat = entry["category"]
            yield f'<p>URL: <a href="{link}">{link}</a> (category: {cat})</p>'
def main(query: str, model_name: str, option: str, temperature: int):
    """Run the retrieval-QA chain for *query* and format the response.

    Returns a (answer, html) pair for the two Gradio outputs. When the
    OpenAI API rejects the request, the first element is a fixed
    Japanese apology and the second is the error text.
    """
    qa = get_retrieval_qa(model_name, temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
    metadata = [doc.metadata for doc in result["source_documents"]]
    links = "\n".join(get_related_url(metadata))
    return result["result"], f"<div>{links}</div>"
# Gradio UI: query box, model picker, category filter, temperature slider.
nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
        gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
        gr.Slider(0, 2),
    ],
    # gr.HTML is the current component API (matching the gr.Textbox/gr.Radio
    # inputs above); the previous gr.outputs.HTML form is deprecated and was
    # removed in Gradio 4.
    outputs=[gr.Textbox(label="answer"), gr.HTML()],
)
nvdajp_book_qa.launch()