File size: 5,571 Bytes
51187cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import time
from langchain_community.llms import HuggingFacePipeline
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
import os
from pypdf import PdfMerger
from argparse import ArgumentParser


# Hugging Face model identifier handed to AutoTokenizer / AutoModelForCausalLM
# at the bottom of this file.
mod = "google/gemma-2b"
# Pipeline task name. NOTE(review): not referenced in the visible part of this
# file; presumably the value passed as `task` to just_chatting — confirm.
tsk = "text-generation"

def merge_pdfs(pdfs: list):
    """
    Merge several PDF files into one PDF and return the path of the result.

    The output path is derived from the last entry of ``pdfs``: its file
    extension is stripped and ``_results.pdf`` is appended.

    Args:
        pdfs (list): Paths of the PDF files to merge, in the order in which
            they should appear in the merged document.

    Returns:
        str: Path of the merged PDF that was written.

    Raises:
        IndexError: If ``pdfs`` is empty.
    """
    # Work out the destination first so that an empty list fails before a
    # merger is created. os.path.splitext strips only the final extension, so
    # dots in directory names ("./docs/a.pdf") or in the file stem
    # ("report.v2.pdf") no longer truncate the path at the first dot.
    output_path = f"{os.path.splitext(pdfs[-1])[0]}_results.pdf"

    merger = PdfMerger()
    try:
        for pdf in pdfs:
            merger.append(pdf)
        merger.write(output_path)
    finally:
        # Always release the merger, even when append/write raises.
        merger.close()
    return output_path

def create_a_persistent_db(pdfpath, dbpath, cachepath) -> "Chroma":
    """
    Build a persistent Chroma vector store from a PDF file.

    The PDF is loaded, split into chunks, and indexed with Hugging Face
    embeddings wrapped in a file-backed cache. Progress messages and the
    elapsed time of each phase are printed to stdout.

    Args:
        pdfpath (str): The path to the PDF file.
        dbpath (str): The folder in which the ``<name>_localDB`` directory of
            the persistent database is created.
        cachepath (str): The folder in which the ``<name>_cache`` directory of
            the embeddings cache is created.

    Returns:
        Chroma: The vector store built from the PDF chunks.
    """
    print("Started the operation...")
    a = time.time()

    # Name shared by the cache folder, the cache namespace and the DB folder.
    # It is everything before the FIRST dot of the file name; this is kept
    # as-is so that folders created by earlier runs are still found.
    stem = os.path.basename(pdfpath).split(".")[0]
    cache_dir = os.path.join(cachepath, stem + "_cache")

    documents = PyPDFLoader(pdfpath).load()

    # Split the documents into smaller chunks for processing.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)

    # Wrap the Hugging Face embedder in a file-backed cache so that vectors
    # that were already computed can be read back from disk on later runs.
    embeddings = HuggingFaceEmbeddings()
    store = LocalFileStore(cache_dir)
    cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
        underlying_embeddings=embeddings,
        document_embedding_cache=store,
        namespace=stem,
    )

    b = time.time()
    # NOTE(review): at this point only the cached embedder has been set up;
    # the chunks are presumably embedded inside Chroma.from_documents below,
    # so this message/timing may describe setup rather than embedding — confirm.
    print(
        f"Embeddings successfully created and stored at {cache_dir} under namespace: {stem}"
    )
    print(f"To load and embed, it took: {b - a}")

    persist_directory = os.path.join(dbpath, stem + "_localDB")
    vectordb = Chroma.from_documents(
        documents=texts,
        embedding=cached_embeddings,
        persist_directory=persist_directory,
    )
    c = time.time()
    print(
        f"Persistent database successfully created and stored at {persist_directory}"
    )
    print(f"To create a persistent database, it took: {c - b}")
    return vectordb

def convert_none_to_str(l: list):
    """
    Return a tuple copy of ``l`` with ``None`` and tuple entries blanked out.

    Every element that is ``None`` or a tuple is replaced by the empty
    string; all other elements are kept unchanged and in order.

    Args:
        l (list): The items to normalise. Any iterable is accepted.

    Returns:
        tuple: The normalised items.
    """
    return tuple(
        "" if item is None or isinstance(item, tuple) else item for item in l
    )

def just_chatting(
    task,
    model,
    tokenizer,
    query,
    vectordb,
    chat_history=None,
):
    """
    Answer a question with a Hugging Face model backed by a vector store.

    A Hugging Face pipeline is built from the given model and tokenizer,
    wrapped as a LangChain LLM, and connected to a
    ConversationalRetrievalChain that retrieves context from ``vectordb``.

    Args:
        task (str): Task for the pipeline; for now supported tasks are
            ['text-generation', 'text2text-generation'].
        model (AutoModelForCausalLM): Hugging Face model, already loaded and
            prepared.
        tokenizer (AutoTokenizer): Hugging Face tokenizer, already loaded and
            prepared.
        query (str): Question asked by the user.
        vectordb (Chroma): Vector store used for retrieval.
        chat_history (list | None): Previous questions and answers, used as
            context. Defaults to an empty history (which may make the model
            hallucinate).

    Returns:
        The value returned by calling the chain. NOTE(review): presumably a
        dict that contains the generated answer — confirm against the
        installed LangChain version.
    """
    # A fresh list per call: a mutable default would be shared between calls.
    if chat_history is None:
        chat_history = []

    # Create a generation pipeline and connect it to a ConversationalRetrievalChain.
    pipe = pipeline(
        task,
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=2048,
        repetition_penalty=float(10),
    )

    local_llm = HuggingFacePipeline(pipeline=pipe)
    llm_chain = ConversationalRetrievalChain.from_llm(
        llm=local_llm,
        chain_type="stuff",
        # Only the single best-matching chunk is handed to the model as context.
        retriever=vectordb.as_retriever(search_kwargs={"k": 1}),
        return_source_documents=False,
    )
    rst = llm_chain({"question": query, "chat_history": chat_history})
    return rst


# Load the tokenizer and the model named by `mod` when this module is run or
# imported; if anything goes wrong, report it on stderr and exit with status 1.
try:
    tokenizer = AutoTokenizer.from_pretrained(mod)
    # Use the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(mod)
except Exception as e:
    import sys

    print(
        f"The error {e} occured while handling model and tokenizer loading: please ensure that the model you provided was correct and suitable for the specified task. Be also sure that the HF repository for the loaded model contains all the necessary files.",
        file=sys.stderr,
    )
    sys.exit(1)