File size: 3,254 Bytes
6125df0
6ab28e5
 
6125df0
 
 
6ab28e5
9022e07
6ab28e5
 
 
9bc4a6c
6ab28e5
 
 
9022e07
 
 
6ab28e5
 
9022e07
6125df0
 
 
 
 
 
 
 
 
 
9bc4a6c
6ab28e5
 
9022e07
 
 
 
 
 
 
 
99d3f35
 
 
 
 
 
 
 
6125df0
9022e07
 
 
 
 
 
 
6ab28e5
6125df0
6ab28e5
 
 
9bc4a6c
6ab28e5
9bc4a6c
 
 
6ab28e5
9bc4a6c
 
 
 
6ab28e5
 
9022e07
 
6ab28e5
 
 
 
 
 
 
 
 
 
 
 
 
99d3f35
 
9022e07
99d3f35
 
 
6ab28e5
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from time import time
import gradio as gr
from langchain.chains import RetrievalQA
# from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from config import DB_CONFIG


# Name of the persisted collection/dir for this corpus.
# NOTE(review): not referenced anywhere in this file — presumably used by a
# sibling ingestion script; confirm before removing.
PERSIST_DIR_NAME = "nvdajp-book"
# Earlier model choices kept for reference; the live mapping is inside
# get_retrieval_qa().
# MODEL_NAME = "text-davinci-003"
# MODEL_NAME = "gpt-3.5-turbo"
# MODEL_NAME = "gpt-4"


def get_retrieval_qa(model_name: str | None, temperature: float, option: str | None) -> RetrievalQA:
    """Build a RetrievalQA chain over the Qdrant vector store.

    Args:
        model_name: UI model label ("GPT-3.5" or "GPT-4"); ``None`` or any
            unrecognized value falls back to ``gpt-3.5-turbo``.
        temperature: sampling temperature forwarded to the chat model
            (the UI slider produces floats in [0, 2]).
        option: document category to filter on; ``None`` or ``"All"``
            searches every document.

    Returns:
        A RetrievalQA chain configured to also return source documents.
    """
    # BUG FIX: the original rebound the *parameter* ``model_name`` to the
    # embedding model id, so the GPT-3.5/GPT-4 branch below always fell
    # through to the default and GPT-4 could never be selected.  The
    # embedding model id now lives in its own local name.
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
    # Map the UI label to an OpenAI model id; None/unknown labels fall back
    # to gpt-3.5-turbo, matching the original if/elif chain.
    model = {"GPT-3.5": "gpt-3.5-turbo", "GPT-4": "gpt-4"}.get(model_name, "gpt-3.5-turbo")
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        # Restrict retrieval to one document category.
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(
            model=model,
            temperature=temperature,
        ),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )


def get_related_url(metadata):
    """Yield one HTML paragraph per distinct source URL found in *metadata*.

    Each item of *metadata* is a mapping with at least "url" and "category"
    keys; duplicate URLs after the first occurrence are skipped.
    """
    seen = set()
    for entry in metadata:
        link = entry["url"]
        if link not in seen:
            seen.add(link)
            category = entry["category"]
            yield f'<p>URL: <a href="{link}">{link}</a> (category: {category})</p>'


def main(query: str, model_name: str, option: str, temperature: int):
    """Answer *query* via the QA chain; return (answer text, related-links HTML).

    On an OpenAI InvalidRequestError the first element is a Japanese
    "no answer found" message and the second is the error text.
    """
    qa = get_retrieval_qa(model_name, temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
    sources = [doc.metadata for doc in result["source_documents"]]
    links_html = "<div>" + "\n".join(get_related_url(sources)) + "</div>"
    return result["result"], links_html


# Gradio UI wiring: the four inputs map positionally onto
# main(query, model_name, option, temperature).
nvdajp_book_qa = gr.Interface(
    fn=main,
    inputs=[
        gr.Textbox(label="query"),
        # Model label; main()'s chain falls back to GPT-3.5 when unset.
        gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
        # Category filter forwarded to the retriever ("All" disables it).
        gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
        # Sampling temperature in [0, 2].
        gr.Slider(0, 2)
    ],
    outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
)


# Start the Gradio app (blocking call).
nvdajp_book_qa.launch()