File size: 3,131 Bytes
6ab28e5
 
 
2c70642
6ab28e5
2c70642
6ab28e5
99d3f35
2c70642
6ab28e5
 
 
 
 
9bc4a6c
 
 
 
 
6ab28e5
 
 
9bc4a6c
99d3f35
9bc4a6c
 
 
2c70642
 
 
 
 
9bc4a6c
 
6ab28e5
 
 
2c70642
 
 
6ab28e5
 
 
 
2c70642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ab28e5
 
 
 
 
2c70642
6ab28e5
 
 
2c70642
6ab28e5
 
2c70642
6ab28e5
 
2c70642
99d3f35
 
 
 
 
 
 
2c70642
99d3f35
 
6ab28e5
 
2c70642
 
 
6ab28e5
 
2c70642
6ab28e5
2c70642
99d3f35
6ab28e5
99d3f35
 
2c70642
99d3f35
2c70642
99d3f35
2c70642
99d3f35
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from langchain.document_loaders import ReadTheDocsLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Qdrant

# from qdrant_client import QdrantClient
from nvda_ug_loader import NVDAUserGuideLoader
from config import DB_CONFIG, DB_E5_CONFIG


CHUNK_SIZE = 500  # max characters per chunk handed to RecursiveCharacterTextSplitter


def _remove_prefix_path(p: str):
    prefix = "data/rtdocs/nvdajp-book.readthedocs.io/"
    return p.removeprefix(prefix)


def get_documents(path: str):
    """Yield documents loaded from *path* with enriched metadata.

    Each document's metadata gains a fixed ``category``, a ``source``
    with the local mirror prefix stripped, and the corresponding public
    ``url`` on nvdajp-book.readthedocs.io.
    """
    base_url = "https://nvdajp-book.readthedocs.io/"
    category = "ja-book"
    loader = ReadTheDocsLoader(path, encoding="utf-8")
    for document in loader.load():
        existing = document.metadata
        relative = _remove_prefix_path(existing["source"])
        document.metadata = {
            **existing,
            "category": category,
            "source": relative,
            "url": f"{base_url}{relative}",
        }
        yield document


def get_text_chunk(docs):
    """Split *docs* into chunks of at most CHUNK_SIZE characters, no overlap."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=0
    )
    return splitter.split_documents(docs)


def store(texts, mname):
    """Embed *texts* with the model selected by *mname* and upload to Qdrant.

    ``mname`` selects both the embedding backend and the target database:
    ``"openai"`` uses OpenAIEmbeddings with DB_CONFIG, ``"e5"`` uses a
    multilingual-e5-large HuggingFace model (on CUDA) with DB_E5_CONFIG.

    Raises:
        ValueError: if *mname* is neither "openai" nor "e5".
    """
    if mname == "openai":
        embeddings = OpenAIEmbeddings()
        url, api_key, collection = DB_CONFIG
    elif mname == "e5":
        embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": False},
        )
        url, api_key, collection = DB_E5_CONFIG
    else:
        raise ValueError("Invalid mname")
    Qdrant.from_documents(
        texts,
        embeddings,
        url=url,
        api_key=api_key,
        collection_name=collection,
    )


def rtd_main(path: str, mname: str):
    """Ingest a local ReadTheDocs mirror: load, chunk, and store with *mname*."""
    store(get_text_chunk(get_documents(path)), mname)


def nul_main(url: str, mname: str):
    """Ingest an NVDA user guide from *url*: load, chunk, and store with *mname*.

    URLs hosted on www.nvda.jp are tagged as the Japanese guide; all
    others as the English guide.
    """
    category = "ja-nvda-user-guide" if "www.nvda.jp" in url else "en-nvda-user-guide"
    documents = NVDAUserGuideLoader(url, category).load()
    store(get_text_chunk(documents), mname)


if __name__ == "__main__":
    # CLI entry point. Usage:
    #   $ python store.py rtd "data/rtdocs/nvdajp-book.readthedocs.io/ja/latest" openai
    #   $ python store.py nul "https://www.nvaccess.org/files/nvda/documentation/userGuide.html" e5
    #   $ python store.py nul "https://www.nvda.jp/nvda2023.1jp/ja/userGuide.html" e5
    import sys

    args = sys.argv
    if len(args) != 4:
        # Fixed message: the script needs three arguments (the old text
        # incorrectly asked for two and omitted mname).
        print("Invalid args: need three args for type, html_path, mname")
    else:
        type_ = args[1]
        path = args[2]
        mname = args[3]
        if type_ == "rtd":
            rtd_main(path, mname)
        elif type_ == "nul":
            nul_main(path, mname)
        else:
            print("No type for store")