Spaces:
Paused
Paused
File size: 1,862 Bytes
6ab28e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from config import get_db_config
# NOTE(review): not referenced anywhere in the visible code; presumably a
# leftover from a local on-disk vector store — confirm before removing.
PERSIST_DIR_NAME = "nvdajp-book"
def get_retrieval_qa() -> RetrievalQA:
    """Assemble the question-answering chain over the Qdrant collection.

    Connection settings (URL, API key, collection name) come from
    ``get_db_config()``. The chain uses the "stuff" chain type with an
    OpenAI LLM at temperature 0 and is configured to return the source
    documents alongside the answer.

    Returns:
        A ``RetrievalQA`` chain whose retriever is the Qdrant vector store.
    """
    embedder = OpenAIEmbeddings()
    url, api_key, collection = get_db_config()
    vector_store = Qdrant(
        client=QdrantClient(url=url, api_key=api_key),
        collection_name=collection,
        embeddings=embedder,
    )
    doc_retriever = vector_store.as_retriever()
    return RetrievalQA.from_chain_type(
        llm=OpenAI(temperature=0),
        chain_type="stuff",
        retriever=doc_retriever,
        return_source_documents=True,
    )
def _remove_prefix_path(p: str):
prefix = "data/rtdocs/nvdajp-book.readthedocs.io/"
return p.removeprefix(prefix)
def get_related_url(metadata):
    """Yield one HTML link paragraph per distinct source document.

    Args:
        metadata: Iterable of mappings, each with a ``'source'`` entry
            holding the document path.

    Yields:
        ``<p>url: <a href="...">...</a></p>`` snippets, one per distinct
        path, in first-seen order. Repeated paths are skipped.

    Raises:
        KeyError: If an entry has no ``'source'`` key.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    seen = set()
    url = "https://nvdajp-book.readthedocs.io/"
    for m in metadata:
        pathname = _remove_prefix_path(m['source'])
        if pathname in seen:
            continue
        seen.add(pathname)
        # Escape before interpolating: a path containing &, <, > or quotes
        # would otherwise break (or inject into) the generated markup.
        href = escape(f"{url}{pathname}", quote=True)
        label = escape(pathname)
        yield f'<p>url: <a href="{href}">{label}</a></p>'
def main(query: str):
    """Answer *query* and list the documents the answer was drawn from.

    Args:
        query: The user's question.

    Returns:
        A pair ``(answer, html)``. On success ``answer`` is the chain's
        result text and ``html`` is a ``<div>`` of links to the source
        documents. If the OpenAI API rejects the request, ``answer`` is a
        fixed fallback message and the second item is the error text.
    """
    chain = get_retrieval_qa()
    try:
        response = chain(query)
    except InvalidRequestError as err:
        return "回答が見つかりませんでした。別な質問をしてみてください", str(err)

    source_metadata = [doc.metadata for doc in response["source_documents"]]
    links = "\n".join(get_related_url(source_metadata))
    return response["result"], f"<div>{links}</div>"
# Gradio UI: one text input for the question; outputs are the answer text
# and an HTML panel with links to the source documents.
nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[gr.Textbox(label="query")],
    # gr.HTML is the top-level component, matching gr.Textbox above; the
    # legacy gr.outputs.* namespace is deprecated.
    outputs=[gr.Textbox(label="answer"), gr.HTML()],
)
nvdajp_book_qa.launch()
|