import os

import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter

huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")

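def load_and_split_document(file):
    """Load a PDF and split its pages into text chunks."""
    # Gradio may hand over either a path string or a file object exposing `.name`.
    path = file if isinstance(file, str) else file.name
    loader = PyPDFLoader(path)
    # Same splitter that load_and_split() falls back to by default, made explicit.
    return loader.load_and_split(text_splitter=RecursiveCharacterTextSplitter())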

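def get_embeddings():
    """Return the sentence-transformers model used to embed document chunks."""
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

def create_database(data, embeddings):
    """Embed the chunks, build a FAISS index, and save it to ./faiss_database."""
    db = FAISS.from_documents(data, embeddings)
    db.save_local("faiss_database")
    return db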

prompt = """
Answer the question based only on the following context:
{context}
Question: {question}
"""

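def get_model():
    """Return the Mistral-7B-Instruct text-generation endpoint on the Hugging Face Hub."""
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        # The text-generation API limits answer length with `max_new_tokens`.
        model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
        huggingfacehub_api_token=huggingface_token,
    )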

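def format_docs(docs):
    """Join the retrieved chunks into one plain-text context block."""
    return "\n\n".join(doc.page_content for doc in docs)

def response(database, model, question):
    """Answer `question` with a retrieve-then-generate chain over `database`."""
    prompt_val = ChatPromptTemplate.from_template(prompt)
    retriever = database.as_retriever()
    parser = StrOutputParser()
    chain = (
        # Fan the question out: fetch matching chunks as context and pass the question through unchanged.
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt_val
        | model
        | parser
    )
    return chain.invoke(question)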

def update_vectors(file):
    if file is None:
        return "Please upload a PDF file."
    data = load_and_split_document(file)
    embed = get_embeddings()
    create_database(data, embed)
    return "Vector store updated successfully."

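def ask_question(question):
    if not question:
        return "Please enter a question."
    # FAISS.load_local fails if no index has been built yet.
    if not os.path.isdir("faiss_database"):
        return "Please upload a PDF and update the vector store first."
    embed = get_embeddings()
    # save_local pickles part of the index, so only load indexes this app wrote itself.
    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    model = get_model()
    return response(database, model, question)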

with gr.Blocks() as demo:
    gr.Markdown("# Chat with your PDF documents")
    
    with gr.Row():
        file_input = gr.File(label="Upload your PDF document", file_types=[".pdf"])
        update_button = gr.Button("Update Vector Store")
    
    update_output = gr.Textbox(label="Update Status")
    update_button.click(update_vectors, inputs=[file_input], outputs=update_output)
    
    with gr.Row():
        question_input = gr.Textbox(label="Ask a question about your documents")
        submit_button = gr.Button("Submit")
    
    answer_output = gr.Textbox(label="Answer")
    submit_button.click(ask_question, inputs=[question_input], outputs=answer_output)

if __name__ == "__main__":
    demo.launch()