"""Gradio question-answering UI over a Qdrant-backed LangChain retriever.

The module builds a ``RetrievalQA`` chain per request from the connection
settings in ``config.DB_CONFIG`` and exposes it through a ``gr.Interface``
that is launched when the module is executed.
"""

from collections.abc import Iterable, Iterator
from html import escape

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI  # noqa: F401  (unused here; kept in case other code imports it from this module)
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient

from config import DB_CONFIG

PERSIST_DIR_NAME = "nvdajp-book"
# MODEL_NAME = "text-davinci-003"
# MODEL_NAME = "gpt-3.5-turbo"
# MODEL_NAME = "gpt-4"

# Model used when the UI selection is missing or not recognised.
_DEFAULT_MODEL = "gpt-3.5-turbo"

# Maps the labels shown in the "Model" radio group to OpenAI model names.
_MODEL_BY_LABEL = {
    "GPT-3.5": "gpt-3.5-turbo",
    "GPT-4": "gpt-4",
}


def get_retrieval_qa(
    model_name: str | None, temperature: float, option: str | None
) -> RetrievalQA:
    """Build a ``RetrievalQA`` chain for the selected model and category.

    Args:
        model_name: UI label of the model ("GPT-3.5" or "GPT-4"). ``None`` or
            any other value falls back to ``gpt-3.5-turbo``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        option: Category used to filter retrieved documents. ``None`` or
            "All" disables the filter.

    Returns:
        A "stuff"-type ``RetrievalQA`` chain that also returns its source
        documents.
    """
    embeddings = OpenAIEmbeddings()
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client,
        collection_name=db_collection_name,
        embeddings=embeddings,
    )

    # Unknown or missing labels use the default model, as before.
    model = _MODEL_BY_LABEL.get(model_name, _DEFAULT_MODEL)

    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )

    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(model=model, temperature=temperature),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )


def get_related_url(metadata: Iterable[dict]) -> Iterator[str]:
    """Yield one HTML paragraph per distinct source URL.

    Args:
        metadata: Metadata mappings of the retrieved documents. Each entry is
            indexed by "url" and "category".

    Yields:
        A ``<p>`` element of the form ``URL: <url> (category: <category>)``
        for the first occurrence of each URL, in input order.
    """
    seen_urls = set()
    for entry in metadata:
        url = entry["url"]
        if url in seen_urls:
            continue
        seen_urls.add(url)
        category = entry["category"]
        # The result is rendered by an HTML component: wrap each entry in a
        # paragraph so entries do not run together, and escape the values so
        # characters such as "&" or "<" cannot break the markup.
        yield f"<p>URL: {escape(str(url))} (category: {escape(str(category))})</p>"


def main(
    query: str, model_name: str | None, option: str | None, temperature: float
) -> tuple[str, str]:
    """Answer ``query`` and list the URLs of the documents that were used.

    Args:
        query: Question text entered by the user.
        model_name: Model label from the UI; may be ``None`` when nothing is
            selected.
        option: Category filter from the UI; may be ``None`` when nothing is
            selected.
        temperature: Sampling temperature from the slider.

    Returns:
        A pair ``(answer, html)``. When the OpenAI request is rejected with
        ``InvalidRequestError``, the answer is a fixed fallback message and
        the second element is the escaped error text.
    """
    qa = get_retrieval_qa(model_name, temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        # The second output is an HTML component, so escape the raw message.
        return "回答が見つかりませんでした。別な質問をしてみてください", escape(str(e))

    metadata = [doc.metadata for doc in result["source_documents"]]
    related_html = "\n" + "\n".join(get_related_url(metadata)) + "\n"
    return result["result"], related_html


nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
        gr.Radio(
            ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"],
            label="絞り込み",
            info="ドキュメント制限する?",
        ),
        gr.Slider(0, 2),
    ],
    # gr.HTML is the top-level component; the gr.outputs namespace is deprecated.
    outputs=[gr.Textbox(label="answer"), gr.HTML()],
)

nvdajp_book_qa.launch()