Files changed (3) hide show
  1. app.py +36 -6
  2. nvda_ug_loader.py +107 -0
  3. store.py +28 -13
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from langchain.chains import RetrievalQA
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.llms import OpenAI
 
5
  from langchain.vectorstores import Qdrant
6
  from openai.error import InvalidRequestError
7
  from qdrant_client import QdrantClient
@@ -9,16 +10,40 @@ from config import DB_CONFIG
9
 
10
 
11
  PERSIST_DIR_NAME = "nvdajp-book"
 
 
 
12
 
13
 
14
- def get_retrieval_qa() -> RetrievalQA:
15
  embeddings = OpenAIEmbeddings()
16
  db_url, db_api_key, db_collection_name = DB_CONFIG
17
  client = QdrantClient(url=db_url, api_key=db_api_key)
18
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
19
- retriever = db.as_retriever()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return RetrievalQA.from_chain_type(
21
- llm=OpenAI(temperature=0), chain_type="stuff", retriever=retriever, return_source_documents=True,
 
 
 
 
 
 
22
  )
23
 
24
 
@@ -35,8 +60,8 @@ def get_related_url(metadata):
35
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
36
 
37
 
38
- def main(query: str):
39
- qa = get_retrieval_qa()
40
  try:
41
  result = qa(query)
42
  except InvalidRequestError as e:
@@ -50,7 +75,12 @@ def main(query: str):
50
 
51
  nvdajp_book_qa = gr.Interface(
52
  fn=main,
53
- inputs=[gr.Textbox(label="query")],
 
 
 
 
 
54
  outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
55
  )
56
 
 
2
  from langchain.chains import RetrievalQA
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.llms import OpenAI
5
+ from langchain.chat_models import ChatOpenAI
6
  from langchain.vectorstores import Qdrant
7
  from openai.error import InvalidRequestError
8
  from qdrant_client import QdrantClient
 
10
 
11
 
12
  PERSIST_DIR_NAME = "nvdajp-book"
13
+ # MODEL_NAME = "text-davinci-003"
14
+ # MODEL_NAME = "gpt-3.5-turbo"
15
+ # MODEL_NAME = "gpt-4"
16
 
17
 
18
+ def get_retrieval_qa(model_name: str | None, temperature: int, option: str | None) -> RetrievalQA:
19
  embeddings = OpenAIEmbeddings()
20
  db_url, db_api_key, db_collection_name = DB_CONFIG
21
  client = QdrantClient(url=db_url, api_key=db_api_key)
22
  db = Qdrant(client=client, collection_name=db_collection_name, embeddings=embeddings)
23
+ if model_name is None:
24
+ model = "gpt-3.5-turbo"
25
+ elif model_name == "GPT-3.5":
26
+ model = "gpt-3.5-turbo"
27
+ elif model_name == "GPT-4":
28
+ model = "gpt-4"
29
+ else:
30
+ model = "gpt-3.5-turbo"
31
+ if option is None or option == "All":
32
+ retriever = db.as_retriever()
33
+ else:
34
+ retriever = db.as_retriever(
35
+ search_kwargs={
36
+ "filter": {"category": option},
37
+ }
38
+ )
39
  return RetrievalQA.from_chain_type(
40
+ llm=ChatOpenAI(
41
+ model=model,
42
+ temperature=temperature
43
+ ),
44
+ chain_type="stuff",
45
+ retriever=retriever,
46
+ return_source_documents=True,
47
  )
48
 
49
 
 
60
  yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
61
 
62
 
63
+ def main(query: str, model_name: str, option: str, temperature: int):
64
+ qa = get_retrieval_qa(model_name, temperature, option)
65
  try:
66
  result = qa(query)
67
  except InvalidRequestError as e:
 
75
 
76
  nvdajp_book_qa = gr.Interface(
77
  fn=main,
78
+ inputs=[
79
+ gr.Textbox(label="query"),
80
+ gr.Radio(["GPT-3.5", "GPT-4"], label="Model", info="選択なしで「3.5」を使用"),
81
+ gr.Radio(["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?"),
82
+ gr.Slider(0, 2)
83
+ ],
84
  outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()],
85
  )
86
 
nvda_ug_loader.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import re
3
+ from typing import Iterator, List
4
+ from langchain.docstore.document import Document
5
+ from langchain.document_loaders.base import BaseLoader
6
+
7
+ from bs4 import BeautifulSoup, Tag, ResultSet
8
+ import requests
9
+
10
+
11
+ RE_HEADERS = re.compile(r"h[23]")
12
+
13
+
14
+ @dataclass
15
+ class Content:
16
+ name: str
17
+ title: str
18
+ text: str
19
+ body: list[Tag]
20
+
21
+
22
+ def _get_anchor_name(header: Tag) -> str:
23
+ for tag in header.previous_elements:
24
+ if tag.name == "a":
25
+ return tag.attrs.get("name", "")
26
+ return ""
27
+
28
+
29
+ def _reversed_remove_last_anchor(body: list[Tag]) -> Iterator[Tag]:
30
+ has_anchor = False
31
+ for tag in reversed(body):
32
+ if not has_anchor:
33
+ if tag.name == "a":
34
+ has_anchor = True
35
+ continue
36
+ else:
37
+ yield tag
38
+
39
+
40
+ def _remove_last_anchor(body: list[Tag]) -> Iterator[Tag]:
41
+ return reversed(list(_reversed_remove_last_anchor(body)))
42
+
43
+
44
+ def _get_bodys_text(body: list[Tag]) -> str:
45
+ text = ""
46
+ for tag in body:
47
+ text += tag.get_text()
48
+ return text
49
+
50
+
51
+ def _get_child_content(header: Tag) -> Content:
52
+ title = header.get_text()
53
+ name = _get_anchor_name(header)
54
+ body = [header]
55
+ for i, child in enumerate(header.next_elements):
56
+ if i == 0:
57
+ continue
58
+ if child.name == "h2" or child.name == "h3":
59
+ break
60
+ body.append(child)
61
+ removed_next_anchor_body = list(_remove_last_anchor(body))
62
+ text = _get_bodys_text(removed_next_anchor_body)
63
+ return Content(name,
64
+ title,
65
+ text,
66
+ removed_next_anchor_body
67
+ )
68
+
69
+
70
+ def get_contents(headers: ResultSet[Tag]) -> Iterator[Content]:
71
+ for header in headers:
72
+ yield _get_child_content(header)
73
+
74
+
75
+ class NVDAUserGuideLoader(BaseLoader):
76
+ """
77
+ """
78
+ def __init__(self, url: str, category: str) -> None:
79
+ self.url = url
80
+ self.category = category
81
+
82
+ def fetch(self) -> BeautifulSoup:
83
+ res = requests.get(self.url)
84
+ soup = BeautifulSoup(res.content, 'lxml')
85
+ return soup
86
+
87
+ def lazy_load(self) -> Iterator[Document]:
88
+ soup = self.fetch()
89
+ # body = soup.body
90
+ headers = soup.find_all(RE_HEADERS)
91
+ for content in get_contents(headers):
92
+ name = content.name
93
+ title = content.title
94
+ text = content.text
95
+ metadata = {"category": self.category, "source": name, "url": f"{self.url}#{name}", "title": title}
96
+ yield Document(page_content=text, metadata=metadata)
97
+
98
+ def load(self) -> List[Document]:
99
+ return list(self.lazy_load())
100
+
101
+
102
+ if __name__ == "__main__":
103
+ url = "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
104
+ loader = NVDAUserGuideLoader(url, "en-nvda-user-guide")
105
+ data = loader.load()
106
+ print(data)
107
+ # breakpoint()
store.py CHANGED
@@ -3,6 +3,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.vectorstores import Qdrant
5
  # from qdrant_client import QdrantClient
 
6
  from config import DB_CONFIG
7
 
8
 
@@ -18,14 +19,13 @@ def get_documents(path: str):
18
  loader = ReadTheDocsLoader(path, encoding="utf-8")
19
  docs = loader.load()
20
  base_url = "https://nvdajp-book.readthedocs.io/"
21
- add_meta = {"category": "ja-book"}
22
  for doc in docs:
23
  org_metadata = doc.metadata
24
  source = _remove_prefix_path(org_metadata["source"])
25
- add_meta = {"category": "ja-book", "source": source, "url": f"{base_url}{source}"}
26
  doc.metadata = org_metadata | add_meta
27
  yield doc
28
- # return docs
29
 
30
 
31
  def get_text_chunk(docs):
@@ -47,24 +47,39 @@ def store(texts):
47
  )
48
 
49
 
50
- def main(path: str):
51
  docs = get_documents(path)
52
  texts = get_text_chunk(docs)
53
  store(texts)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
56
  if __name__ == "__main__":
57
  """
58
- $ python store.py "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
 
 
59
  """
60
  import sys
61
  args = sys.argv
62
- if len(args) != 2:
63
- print("No args, you need two args for html_path")
64
- docs = get_documents("data/rtdocs/nvdajp-book.readthedocs.io/ja/latest")
65
- print(type(docs))
66
- breakpoint()
67
  else:
68
- path = args[1]
69
- # dir_name = args[2]
70
- main(path)
 
 
 
 
 
 
3
  from langchain.embeddings import OpenAIEmbeddings
4
  from langchain.vectorstores import Qdrant
5
  # from qdrant_client import QdrantClient
6
+ from nvda_ug_loader import NVDAUserGuideLoader
7
  from config import DB_CONFIG
8
 
9
 
 
19
  loader = ReadTheDocsLoader(path, encoding="utf-8")
20
  docs = loader.load()
21
  base_url = "https://nvdajp-book.readthedocs.io/"
22
+ category = "ja-book"
23
  for doc in docs:
24
  org_metadata = doc.metadata
25
  source = _remove_prefix_path(org_metadata["source"])
26
+ add_meta = {"category": category, "source": source, "url": f"{base_url}{source}"}
27
  doc.metadata = org_metadata | add_meta
28
  yield doc
 
29
 
30
 
31
  def get_text_chunk(docs):
 
47
  )
48
 
49
 
50
+ def rtd_main(path: str):
51
  docs = get_documents(path)
52
  texts = get_text_chunk(docs)
53
  store(texts)
54
 
55
 
56
+ def nul_main(url: str):
57
+ if "www.nvda.jp" in url:
58
+ category = "ja-nvda-user-guide"
59
+ else:
60
+ category = "en-nvda-user-guide"
61
+ loader = NVDAUserGuideLoader(url, category)
62
+ docs = loader.load()
63
+ texts = get_text_chunk(docs)
64
+ store(texts)
65
+
66
+
67
  if __name__ == "__main__":
68
  """
69
+ $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest"
70
+ $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html"
71
+ $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html"
72
  """
73
  import sys
74
  args = sys.argv
75
+ if len(args) != 3:
76
+ print("No args, you need two args for type, html_path")
 
 
 
77
  else:
78
+ type_ = args[1]
79
+ path = args[2]
80
+ if type_ == "rtd":
81
+ rtd_main(path)
82
+ elif type_ == "nul":
83
+ nul_main(path)
84
+ else:
85
+ print("No type for store")