File size: 4,401 Bytes
42aece6
4f570b0
 
 
 
 
 
 
 
 
 
 
6028f6f
 
 
bb88ab4
 
4f570b0
 
 
 
 
 
 
 
4d68f80
4f570b0
 
 
4d68f80
4f570b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d68f80
42aece6
02442ac
 
6028f6f
 
38ff411
 
6028f6f
 
42aece6
6028f6f
 
42aece6
6028f6f
 
 
42aece6
 
bb88ab4
4f570b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os 
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_huggingface import HuggingFacePipeline




class AdjustedHuggingFaceEmbeddings(HuggingFaceEmbeddings):
    """HuggingFace embeddings that can also be invoked like a plain function.

    Adds a ``__call__`` so an instance can be handed to code that expects an
    embedding *callable* rather than an object with ``embed_documents``.
    """

    def __call__(self, input):
        """Embed a batch of texts.

        Args:
            input: The texts to embed. The name shadows the builtin, but it is
                kept unchanged because it is part of the callable's signature.

        Returns:
            Whatever ``embed_documents`` returns for ``input``.
        """
        # NOTE(review): the LangChain ``Embeddings`` interface exposes
        # ``embed_documents`` / ``embed_query`` and, as far as its public API
        # shows, no ``__call__`` -- so delegating to ``super().__call__`` had
        # nothing to dispatch to. Delegate to the real embedding method.
        return self.embed_documents(input)


def create_chain(chains, pdf_doc):
    """Build a history-aware retrieval chain for a PDF and store it.

    The document is indexed via ``create_vector_db``, a model is obtained from
    ``create_model``, and the resulting retrieval chain is written into
    ``chains[0]``.

    Args:
        chains: Mutable, indexable container; slot 0 is overwritten with the
            newly built chain.
        pdf_doc: The PDF handle forwarded to ``create_vector_db``, or ``None``
            when nothing has been uploaded yet.

    Returns:
        A human-readable status message.
    """
    if pdf_doc is None:
        return 'You must convert or upload a pdf first'

    db = create_vector_db(pdf_doc)
    llm = create_model()

    # First stage: turn the conversation so far into a search query.
    prompt_search_query = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user",
         "Given the above conversation, generate a search query to look up to get information relevant to the conversation"),
    ])
    retriever_chain = create_history_aware_retriever(llm, db.as_retriever(), prompt_search_query)

    # Second stage: answer from the retrieved context.
    # The separator must be real newlines; the doubled backslashes used before
    # put the literal characters backslash-n into the prompt text.
    prompt_get_answer = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_get_answer)

    chains[0] = create_retrieval_chain(retriever_chain, combine_docs_chain)
    return 'Document has successfully been loaded'


def create_model(model_id="google-bert/bert-large-uncased-whole-word-masking-finetuned-squad",
                 max_new_tokens=1024):
    """Load a Hugging Face model and wrap it as a LangChain LLM.

    Args:
        model_id: Hub identifier used for both the tokenizer and the model.
            Defaults to the checkpoint this module has always used.
        max_new_tokens: Generation length limit passed to the pipeline.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``text-generation`` pipeline.

    NOTE(review): the default checkpoint looks like a BERT question-answering
    model, yet it is loaded through ``AutoModelForCausalLM`` and used for text
    generation -- confirm this is the intended generator, or pass a causal
    language model via ``model_id``.
    """
    # Read from the environment; None when the variable is unset.
    hf_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")

    # Tokenizer and model use the same id and the same token so that a gated
    # or private checkpoint works for both downloads.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map='auto',
        torch_dtype=torch.float16,
        token=hf_api_token,
    )

    # The model is already instantiated with its dtype and device placement,
    # so the pipeline is not given a second (and previously conflicting:
    # bfloat16 vs. float16) ``torch_dtype`` / ``device_map``.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # NOTE(review): ``temperature`` here sits in ``model_kwargs`` of the
    # wrapper while the pipeline itself samples (``do_sample=True``); it
    # presumably does not influence generation -- confirm and reconcile.
    return HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})


def create_vector_db(doc):
    """Index a document in a Chroma vector store.

    The document is loaded with ``load_document``, chunked with
    ``split_document``, and the chunks are embedded with
    ``AdjustedHuggingFaceEmbeddings``.

    Args:
        doc: The document handle accepted by ``load_document``.

    Returns:
        The store returned by ``Chroma.from_documents``.
    """
    chunks = split_document(load_document(doc))
    embedder = AdjustedHuggingFaceEmbeddings()
    return Chroma.from_documents(chunks, embedder)


def load_document(doc):
    """Load a PDF through ``PyMuPDFLoader``.

    Args:
        doc: Object whose ``name`` attribute is handed to the loader
            (presumably a file path -- verify against the caller).

    Returns:
        The result of the loader's ``load()`` call.
    """
    return PyMuPDFLoader(doc.name).load()


def split_document(doc):
    """Split loaded documents into overlapping chunks.

    Args:
        doc: The loaded documents to pass to ``split_documents``.

    Returns:
        The chunks produced by a ``RecursiveCharacterTextSplitter`` configured
        with ``chunk_size=1000`` and ``chunk_overlap=200``.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return splitter.split_documents(doc)


def save_history(history):
    """Overwrite ``history.txt`` with one bullet line per message.

    Args:
        history: Iterable of objects exposing a ``content`` attribute.

    The file is created in the current working directory and truncated on
    every call, so it always reflects exactly the messages passed in.
    """
    # Explicit UTF-8: chat text can contain any Unicode character, and the
    # platform default encoding may not be able to represent it.
    with open('history.txt', 'w', encoding='utf-8') as file:
        file.writelines(f'- {message.content}\n' for message in history)


def answer_query(chain, query: str, chat_history=None) -> str:
    """Run one question through the chain and record the exchange.

    Args:
        chain: The chain to invoke; any falsy value means no document has
            been loaded yet.
        query: The user's question.
        chat_history: Mutable list of earlier messages. It is extended in
            place with the new question and answer. When omitted, a fresh
            list is used for this call only.

    Returns:
        The chain's answer, or a hint to load a document when ``chain`` is
        falsy.
    """
    if not chain:
        return "Please load a document first."

    # The default is None (not a shared mutable list); without this fallback
    # the appends below would fail when the caller omits the history.
    if chat_history is None:
        chat_history = []

    # Invoke with the history *before* this turn: the current question is
    # already supplied as 'input', so appending it first would make the
    # prompt contain the same user message twice.
    response = chain.invoke({
        'chat_history': chat_history,
        'input': query
    })
    answer = response['answer']
    print('RESPONSE: ', answer, '\n\n')

    # Record the completed exchange and persist it to the history file.
    chat_history.append(HumanMessage(content=query))
    chat_history.append(AIMessage(content=answer))
    save_history(chat_history)
    return answer