Spaces:
Paused
Paused
File size: 3,254 Bytes
6125df0 6ab28e5 6125df0 6ab28e5 9022e07 6ab28e5 9bc4a6c 6ab28e5 9022e07 6ab28e5 9022e07 6125df0 9bc4a6c 6ab28e5 9022e07 99d3f35 6125df0 9022e07 6ab28e5 6125df0 6ab28e5 9bc4a6c 6ab28e5 9bc4a6c 6ab28e5 9bc4a6c 6ab28e5 9022e07 6ab28e5 99d3f35 9022e07 99d3f35 6ab28e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from time import time
import gradio as gr
from langchain.chains import RetrievalQA
# from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from config import DB_CONFIG
# Name of the local persistence directory / collection for this corpus.
PERSIST_DIR_NAME = "nvdajp-book"
# Previously tried chat/completion models, kept for reference:
# MODEL_NAME = "text-davinci-003"
# MODEL_NAME = "gpt-3.5-turbo"
# MODEL_NAME = "gpt-4"
def get_retrieval_qa(model_name: str | None, temperature: float, option: str | None) -> RetrievalQA:
    """Build a RetrievalQA chain over the Qdrant vector store.

    Args:
        model_name: UI label of the chat model ("GPT-3.5" or "GPT-4").
            ``None`` or any unrecognized value falls back to gpt-3.5-turbo.
        temperature: Sampling temperature forwarded to ChatOpenAI
            (the UI slider produces floats in [0, 2]).
        option: Document category to filter retrieval by; ``None`` or
            "All" searches the entire collection.

    Returns:
        A "stuff"-type RetrievalQA chain configured to also return the
        source documents it retrieved.
    """
    # BUG FIX: the original code reassigned the *parameter* `model_name` to
    # the embeddings model id here, so the GPT-3.5/GPT-4 branch below could
    # never match and the "GPT-4" radio choice was silently ignored.
    # Keep the embeddings model in its own variable instead.
    # embeddings = OpenAIEmbeddings()
    embeddings_model_name = "sentence-transformers/all-mpnet-base-v2"
    embeddings = HuggingFaceEmbeddings(
        model_name=embeddings_model_name,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )
    # embeddings = GPT4AllEmbeddings()
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
    # Map the UI label to the actual OpenAI model id; default to 3.5.
    if model_name == "GPT-4":
        model = "gpt-4"
    else:
        # Covers None, "GPT-3.5", and any unexpected value.
        model = "gpt-3.5-turbo"
    # Restrict retrieval to one document category unless "All" is chosen.
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(
            model=model,
            temperature=temperature,
        ),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
def get_related_url(metadata):
    """Yield an HTML paragraph linking each unique source URL.

    Duplicate URLs in *metadata* are emitted only once, in first-seen
    order; each link is annotated with its document category.
    """
    seen: set[str] = set()
    for entry in metadata:
        link = entry["url"]
        if link not in seen:
            seen.add(link)
            category = entry["category"]
            yield f'<p>URL: <a href="{link}">{link}</a> (category: {category})</p>'
def main(query: str, model_name: str, option: str, temperature: int):
    """Answer *query* with the QA chain; return (answer, related-links HTML).

    On an InvalidRequestError from OpenAI, returns a Japanese
    "no answer found, try another question" message plus the error text.
    """
    chain = get_retrieval_qa(model_name, temperature, option)
    try:
        response = chain(query)
    except InvalidRequestError as err:
        return "回答が見つかりませんでした。別な質問をしてみてください", str(err)
    docs = response["source_documents"]
    links = "\n".join(get_related_url([doc.metadata for doc in docs]))
    return response["result"], f"<div>{links}</div>"
# Gradio UI: query textbox, model choice, category filter, and a
# temperature slider (0-2) wired to main(); outputs the answer text and
# an HTML list of source links.
nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        # Radio labels/info are user-facing Japanese strings:
        # info: "use 3.5 when nothing is selected"
        gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
        # label: "narrow down", info: "restrict to a document set?"
        gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
        gr.Slider(0, 2)
    ],
    outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
)
# Start the Gradio web server (blocks until shut down).
nvdajp_book_qa.launch()
|