# %%
import nltk
from langchain.indexes import VectorstoreIndexCreator
from langchain.text_splitter import CharacterTextSplitter, NLTKTextSplitter
from langchain.document_loaders import OnlinePDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import LlamaCppEmbeddings, HuggingFaceInstructEmbeddings
from chromadb.config import Settings
import chromadb
from chromadb.utils import embedding_functions
from hashlib import sha256
import cloudpickle
import logging
import os
from load_model import load_embedding
import torch
import re
import pathlib

current_path = str( pathlib.Path(__file__).parent.resolve() )

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
nltk.download('punkt')

persist_directory = current_path + "/VectorStore"
logger = logging.getLogger()


# %%

def create_collection(collection_name, model_name, client):
    """Create (or fetch) a Chroma collection with an Instructor embedding function.

    Currently unused.
    """
    # Prefer the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ef = embedding_functions.InstructorEmbeddingFunction(
        model_name=model_name, device=device)
    client.get_or_create_collection(collection_name, embedding_function=ef)
    return True

def create_and_add(collection_name, sub_docs, model_name):
    """Embed `sub_docs` with `model_name` and persist them in a Chroma collection."""
    client_settings = chromadb.config.Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=persist_directory,
        anonymized_telemetry=False
    )

    client = chromadb.Client(client_settings)
    # Suffix the collection with the sanitized model name so each embedding
    # model gets its own collection.
    collection_name = collection_name + "_" + re.sub('[^A-Za-z0-9]+', '', model_name)

    embeddings = load_embedding(model_name)
    logger.info(f"Adding documents to {collection_name}")
    vectorstore = Chroma(
        collection_name=collection_name,
        embedding_function=embeddings,
        client_settings=client_settings,
        persist_directory=persist_directory,
    )
    vectorstore.add_documents(documents=sub_docs, embedding=embeddings)
    vectorstore.persist()

    # Sanity check: reopen the persisted collection and run a sample query.
    vectorstore2 = Chroma(
        collection_name=collection_name,
        embedding_function=embeddings,
        client_settings=client_settings,
        persist_directory=persist_directory,
    )
    print(vectorstore2.similarity_search_with_score(query="What are AXAs green Goals?", k=4))

    return vectorstore
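
# %%
# Illustrative sketch, not part of the original pipeline: create_and_add derives
# the stored collection name by appending the model name with every
# non-alphanumeric character removed. The helper name below is hypothetical and
# only mirrors that one expression, so the resulting name can be predicted
# before a collection is reopened. `re` is already imported at the top of the
# file; it is repeated here so the cell runs on its own.
import re


def sketch_collection_name(collection_name, model_name):
    """Return the name create_and_add would use for this collection/model pair."""
    return collection_name + "_" + re.sub('[^A-Za-z0-9]+', '', model_name)


# Example with an assumed model name: "docs" and "hkunlp/instructor-large"
# give "docs_hkunlpinstructorlarge".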

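def load_from_web(urls, cache=True):
    """Load PDFs from `urls`, caching the parsed documents in a local pickle file."""
    docs_list = urls
    # The cache key is a hash of the URL list, so a different list (or order)
    # results in a cache miss.
    filename = f"./{sha256(str(urls).encode('utf-8')).hexdigest()}.pkl"

    is_file = os.path.isfile(filename)

    if cache and is_file:
        logger.info("Using Cache")
        # Use a context manager so the cache file is always closed.
        with open(filename, "rb") as pikd:
            docs = cloudpickle.load(pikd)
    else:
        loaders = [OnlinePDFLoader(pdf) for pdf in docs_list]
        docs = []
        for loader in loaders:
            docs.extend(loader.load())
        with open(filename, 'wb') as output:
            cloudpickle.dump(docs, output)

    # Update metadata; this assumes each URL yields exactly one document, in order.
    for i, doc in enumerate(docs):
        doc.metadata = {'source': docs_list[i], 'url': docs_list[i], 'company': 'AXA'}
    return docs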
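
# %%
# Illustrative sketch, not part of the original pipeline: load_from_web names its
# cache file after the SHA-256 of str(urls), so the same URLs passed in a
# different order map to a different cache file. The helper name is hypothetical
# and only mirrors that one expression. `sha256` is already imported at the top
# of the file; it is repeated here so the cell runs on its own.
from hashlib import sha256


def sketch_cache_filename(urls):
    """Return the pickle path load_from_web would use for this URL list."""
    return f"./{sha256(str(urls).encode('utf-8')).hexdigest()}.pkl"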
        
def load_and_split(docs, chunk_size=700):
    """Split `docs` into sentence-aligned chunks of roughly `chunk_size` characters."""
    text_splitter = NLTKTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
    sub_docs = text_splitter.split_documents(docs)
    return sub_docs