nvdajp-book-qa / store.py
terapyon's picture
dev/modify-embedding-test (#4)
2c70642
raw
history blame
3.13 kB
from langchain.document_loaders import ReadTheDocsLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Qdrant
# from qdrant_client import QdrantClient
from nvda_ug_loader import NVDAUserGuideLoader
from config import DB_CONFIG, DB_E5_CONFIG
# Target chunk length handed to RecursiveCharacterTextSplitter as ``chunk_size``
# in get_text_chunk (chunks are produced with no overlap there).
CHUNK_SIZE = 500
def _remove_prefix_path(p: str):
    """Return *p* relative to the local ReadTheDocs mirror directory.

    If *p* starts with ``data/rtdocs/nvdajp-book.readthedocs.io/`` that
    leading part is cut off; any other string is returned unchanged.
    """
    local_root = "data/rtdocs/nvdajp-book.readthedocs.io/"
    if not p.startswith(local_root):
        return p
    return p[len(local_root):]
def get_documents(path: str):
    """Lazily load ReadTheDocs HTML pages under *path* and enrich their metadata.

    Each yielded document keeps the loader's original metadata, with these
    keys set (overriding any previous value):

    - ``category``: always ``"ja-book"``
    - ``source``: the loader's ``source`` with the local mirror prefix
      stripped by ``_remove_prefix_path``
    - ``url``: ``https://nvdajp-book.readthedocs.io/`` followed by that
      relative ``source``

    NOTE(review): the URL is only correct when *path* lives under
    ``data/rtdocs/nvdajp-book.readthedocs.io/`` so the prefix actually
    matches — confirm against how the mirror is downloaded.

    This is a generator: loading starts on the first ``next()`` call.
    """
    site_root = "https://nvdajp-book.readthedocs.io/"
    for page in ReadTheDocsLoader(path, encoding="utf-8").load():
        relative = _remove_prefix_path(page.metadata["source"])
        page.metadata = {
            **page.metadata,
            "category": "ja-book",
            "source": relative,
            "url": f"{site_root}{relative}",
        }
        yield page
def get_text_chunk(docs, *, chunk_size=None, chunk_overlap=0):
    """Split *docs* into smaller documents with RecursiveCharacterTextSplitter.

    Args:
        docs: Iterable of LangChain documents. May be a one-shot generator,
            such as the one returned by ``get_documents``.
        chunk_size: Target chunk length forwarded to the splitter. ``None``
            (the default) means the module constant ``CHUNK_SIZE``.
        chunk_overlap: Overlap between consecutive chunks; defaults to 0.

    Returns:
        The list returned by the splitter's ``split_documents``.
    """
    # Materialize first: older langchain releases of split_documents read the
    # input twice (once for page_content, once for metadata). With a generator
    # the second pass is empty and every chunk silently loses its metadata.
    docs = list(docs)
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(docs)
def store(texts, mname, *, device="cuda"):
    """Embed *texts* and upload them to the Qdrant collection for *mname*.

    Args:
        texts: Documents to embed, e.g. the output of ``get_text_chunk``.
        mname: Embedding backend selector.
            ``"openai"`` uses ``OpenAIEmbeddings`` and ``DB_CONFIG``.
            ``"e5"`` uses ``intfloat/multilingual-e5-large`` through
            ``HuggingFaceEmbeddings`` and ``DB_E5_CONFIG``.
        device: Device string passed to the e5 model via ``model_kwargs``
            (ignored for ``"openai"``). Defaults to ``"cuda"`` as before;
            pass e.g. ``"cpu"`` on a machine without a GPU.

    Returns:
        The vector store created by ``Qdrant.from_documents``. It was
        previously discarded; returning it lets callers query it directly.

    Raises:
        ValueError: If *mname* is neither ``"openai"`` nor ``"e5"``.
    """
    if mname == "openai":
        embeddings = OpenAIEmbeddings()
        # DB_CONFIG / DB_E5_CONFIG unpack as (url, api_key, collection_name).
        db_url, db_api_key, db_collection_name = DB_CONFIG
    elif mname == "e5":
        embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": False},
        )
        db_url, db_api_key, db_collection_name = DB_E5_CONFIG
    else:
        raise ValueError(f"Invalid mname: {mname!r} (expected 'openai' or 'e5')")
    return Qdrant.from_documents(
        texts,
        embeddings,
        url=db_url,
        api_key=db_api_key,
        collection_name=db_collection_name,
    )
def rtd_main(path: str, mname: str):
    """Index the local ReadTheDocs mirror at *path* with embedding backend *mname*.

    Runs the pipeline load (``get_documents``) -> split
    (``get_text_chunk``) -> embed and upload (``store``).
    """
    store(get_text_chunk(get_documents(path)), mname)
def nul_main(url: str, mname: str):
    """Index the NVDA user guide page at *url* with embedding backend *mname*.

    URLs containing ``www.nvda.jp`` are tagged ``"ja-nvda-user-guide"``;
    every other URL is tagged ``"en-nvda-user-guide"``.
    """
    category = (
        "ja-nvda-user-guide" if "www.nvda.jp" in url else "en-nvda-user-guide"
    )
    chunks = get_text_chunk(NVDAUserGuideLoader(url, category).load())
    store(chunks, mname)
if __name__ == "__main__":
    # Command-line entry point: store.py <type> <path_or_url> <mname>
    #   type:  "rtd" (local ReadTheDocs mirror) or "nul" (NVDA user guide URL)
    #   mname: embedding backend, "openai" or "e5"
    # Examples:
    #   $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest" openai
    #   $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html" e5
    #   $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html" e5
    import sys

    args = sys.argv
    if len(args) != 4:
        # sys.exit(str) writes the message to stderr and exits with status 1,
        # so shell callers can detect misuse (print() exited with status 0).
        sys.exit(
            "Invalid args, you need three args: "
            "type (rtd|nul), path_or_url, mname (openai|e5)"
        )
    type_, path, mname = args[1:]
    if type_ == "rtd":
        rtd_main(path, mname)
    elif type_ == "nul":
        nul_main(path, mname)
    else:
        sys.exit(f"No type for store: {type_!r} (expected 'rtd' or 'nul')")