"""Gradio question-answering app over documents stored in a Qdrant collection."""

import html
from typing import Optional

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient

from config import DB_CONFIG

PERSIST_DIR_NAME = "nvdajp-book"


def get_retrieval_qa(temperature: float, option: Optional[str]) -> RetrievalQA:
    """Build a RetrievalQA chain backed by the configured Qdrant collection.

    Args:
        temperature: Value forwarded to ``OpenAI(temperature=...)``.
        option: Category used to filter retrieved documents. ``None`` or
            ``"All"`` disables filtering.

    Returns:
        A "stuff"-type RetrievalQA chain that also returns its source
        documents.
    """
    embeddings = OpenAIEmbeddings()
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client,
        collection_name=db_collection_name,
        embeddings=embeddings,
    )
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        # Restrict the search to documents whose metadata category matches.
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    return RetrievalQA.from_chain_type(
        llm=OpenAI(temperature=temperature),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )


def get_related_url(metadata):
    """Yield one HTML paragraph per distinct source URL found in *metadata*.

    Args:
        metadata: Iterable of mappings, each providing ``"url"`` and
            ``"category"`` keys.

    Yields:
        An HTML ``<p>`` fragment linking to the URL and naming its category.
        Entries whose URL was already yielded are skipped.
    """
    seen_urls = set()
    for m in metadata:
        url = m["url"]
        if url in seen_urls:
            continue
        seen_urls.add(url)
        category = m["category"]
        # Metadata comes from the document store; escape it before embedding
        # it in markup so stray characters cannot break or inject HTML.
        safe_url = html.escape(str(url), quote=True)
        safe_category = html.escape(str(category))
        yield (
            f'<p>URL: <a href="{safe_url}">{safe_url}</a> '
            f"(category: {safe_category})</p>"
        )


def main(query: str, option: Optional[str], temperature: float):
    """Answer *query* and list the URLs of the documents that were used.

    Args:
        query: The user's question.
        option: Category filter selected in the UI (``None`` or ``"All"``
            means no filtering).
        temperature: Value forwarded to the language model.

    Returns:
        A ``(answer, html)`` tuple. When the OpenAI request is rejected with
        ``InvalidRequestError``, the answer is a fixed fallback message and
        the second element is the escaped error text.
    """
    qa = get_retrieval_qa(temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        # The second output is rendered as HTML, so escape the error text.
        return "回答が見つかりませんでした。別な質問をしてみてください", html.escape(str(e))

    metadata = [s.metadata for s in result["source_documents"]]
    related_html = "<div>" + "\n".join(get_related_url(metadata)) + "</div>"
    return result["result"], related_html


nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        gr.Radio(
            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
            label="絞り込み",
            info="ドキュメント制限する?",
        ),
        gr.Slider(0, 2),
    ],
    outputs=[gr.Textbox(label="answer"), gr.HTML()],
)

nvdajp_book_qa.launch()