File size: 3,535 Bytes
4f570b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder



class AdjustedHuggingFaceEmbeddings(HuggingFaceEmbeddings):
    """``HuggingFaceEmbeddings`` subclass that declares an explicit ``__call__``.

    The call is forwarded unchanged to the parent class.
    NOTE(review): confirm that the base class actually provides ``__call__``;
    that is not visible from this file.
    """

    def __call__(self, input):
        """Forward *input* to the parent's ``__call__`` and return its result."""
        parent_call = super().__call__
        return parent_call(input)


def create_chain(chains, pdf_doc, use_local_model=True):
    """Build a history-aware retrieval chain for a PDF and store it in ``chains[0]``.

    Args:
        chains: Mutable sequence; its first slot is overwritten with the
            newly built retrieval chain.
        pdf_doc: Document handle handed to ``create_vector_db``. ``None``
            means no document is available yet.
        use_local_model: Forwarded to ``create_model`` to choose the LLM.

    Returns:
        A status message: a prompt to upload a PDF when ``pdf_doc`` is
        ``None``, otherwise a success message.
    """
    if pdf_doc is None:
        return 'You must convert or upload a pdf first'

    db = create_vector_db(pdf_doc)
    llm = create_model(use_local_model)

    # Prompt that rewrites the conversation so far into a standalone search query.
    prompt_search_query = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user",
         "Given the above conversation, generate a search query to look up to get information relevant to the conversation"),
    ])
    retriever_chain = create_history_aware_retriever(llm, db.as_retriever(), prompt_search_query)

    # Prompt that answers from the retrieved context. The separator must be
    # real newlines: an escaped "\\n" would put literal backslash-n characters
    # in front of the context.
    prompt_get_answer = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_get_answer)

    # The result is published through the caller-supplied container.
    chains[0] = create_retrieval_chain(retriever_chain, combine_docs_chain)
    return 'Document has successfully been loaded'


def create_model(local: bool):
    """Return the LLM to use.

    A truthy ``local`` selects the Ollama ``phi`` model; otherwise a
    Hugging Face endpoint for the OpenAssistant Pythia model is returned.
    """
    if not local:
        return HuggingFaceEndpoint(
            repo_id="OpenAssistant/oasst-sft-1-pythia-12b",
            model_kwargs={"max_length": 256},
            temperature=1.0,
        )
    return Ollama(model='phi')


def create_vector_db(doc):
    """Load *doc*, split it into chunks and index the chunks in a Chroma store.

    Returns the Chroma store built from the chunks using
    ``AdjustedHuggingFaceEmbeddings``.
    """
    pages = load_document(doc)
    chunks = split_document(pages)
    return Chroma.from_documents(chunks, AdjustedHuggingFaceEmbeddings())


def load_document(doc):
    """Load the PDF at ``doc.name`` with ``PyMuPDFLoader`` and return the result."""
    return PyMuPDFLoader(doc.name).load()


def split_document(doc):
    """Split *doc* into chunks of 1000 characters with a 200-character overlap."""
    splitter = RecursiveCharacterTextSplitter(chunk_overlap=200, chunk_size=1000)
    return splitter.split_documents(doc)


def save_history(history):
    """Overwrite ``history.txt`` with one ``- <content>`` line per history entry.

    Args:
        history: Iterable of objects exposing a ``content`` attribute.
    """
    # Explicit UTF-8: the platform default encoding may be unable to
    # represent non-ASCII text and would raise UnicodeEncodeError.
    with open('history.txt', 'w', encoding='utf-8') as file:
        file.writelines(f'- {s.content}\n' for s in history)


def answer_query(chain, query: str, chat_history=None) -> str:
    """Run *query* through *chain* and record the exchange in the history.

    Args:
        chain: Retrieval chain exposing ``invoke``; a falsy value means no
            document has been loaded yet.
        query: The user's question.
        chat_history: Mutable list of prior messages. It is extended in
            place with the new question and answer. ``None`` starts a
            fresh history for this call.

    Returns:
        The chain's answer, or a prompt to load a document when *chain*
        is falsy.
    """
    if not chain:
        return "Please load a document first."

    # The default of None cannot be appended to; start an empty history.
    if chat_history is None:
        chat_history = []

    # Pass only the PRIOR turns as history: the current question is already
    # supplied through 'input', so adding it to the history first would send
    # it to the prompts twice.
    response = chain.invoke({
        'chat_history': chat_history,
        'input': query
    })
    answer = response['answer']
    print('RESPONSE: ', answer, '\n\n')

    # Record the completed turn, then persist the whole history to disk.
    chat_history.append(HumanMessage(content=query))
    chat_history.append(AIMessage(content=answer))
    save_history(chat_history)
    return answer