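"""Streamlit app for chatting with an uploaded PDF via LlamaIndex and a
Hugging Face inference LLM.

Run with (file name assumed): streamlit run app.py
Requires a .env file containing HF_TOKEN (a Hugging Face API token).
"""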
import base64
import os

import streamlit as st
from dotenv import load_dotenv
from llama_index.core import (
    ChatPromptTemplate,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceInferenceAPI

# Load environment variables
load_dotenv()

# Configure LLM and Embedding settings
Settings.llm = HuggingFaceInferenceAPI(
    model_name="google/gemma-1.1-7b-it",
    tokenizer_name="google/gemma-1.1-7b-it",
    context_window=3000,
    token=os.getenv("HF_TOKEN"),
    max_new_tokens=512,
    generate_kwargs={"temperature": 0.1},
)
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)

# Define directory paths
PERSIST_DIR = "./db"
DATA_DIR = "data"

# Create directories if they don't exist
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)

def display_pdf(file):
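    """Render a local PDF inline by embedding it as a base64 data-URI iframe.

    Helper for previewing the uploaded file; not currently called in the app flow.
    """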
    with open(file, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode('utf-8')
    pdf_display = f'<iframe src="data:application/pdf;base64,{base64_pdf}" width="100%" height="600" type="application/pdf"></iframe>'
    st.markdown(pdf_display, unsafe_allow_html=True)

def ingest_data():
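    """Build a vector index over all documents in DATA_DIR and persist it to PERSIST_DIR."""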
    documents = SimpleDirectoryReader(DATA_DIR).load_data()
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)

def handle_query(query):
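    """Load the persisted index and answer `query` using a custom QA prompt."""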
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    chat_text_qa_msgs = [
        (
            "user",
            """You are a Q&A chatbot created by Prateek Mohan. Your main goal is to provide accurate answers based on the given context. If a question is outside the scope of the document, kindly advise the user to ask within the context.
            Context:
            {context_str}
            Question:
            {query_str}
            """,
        )
    ]
    text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
    
    query_engine = index.as_query_engine(text_qa_template=text_qa_template)
    answer = query_engine.query(query)
    
    if hasattr(answer, 'response'):
        return answer.response
    elif isinstance(answer, dict) and 'response' in answer:
        return answer['response']
    else:
        return "Sorry, I couldn't find an answer."

# Streamlit app
st.title("Talk to your  PDF")
st.markdown("by Prateek Mohan (https://github.com/prtkmhn/)")

if 'messages' not in st.session_state:
    st.session_state.messages = [{'role': 'system', 'content': 'Chat to PDF'}]

with st.sidebar:
    st.title("Menu:")
    uploaded_file = st.file_uploader("Upload a PDF and click Submit")
    if st.button("Submit & Process"):
        if uploaded_file is None:
            st.warning("Please upload a PDF first.")
        else:
            with st.spinner("Processing..."):
                # Save the upload into DATA_DIR, then (re)build the index
                filepath = os.path.join(DATA_DIR, "saved_pdf.pdf")
                with open(filepath, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                ingest_data()
                st.success("Done")

user_prompt = st.chat_input("Query")
if user_prompt:
    st.session_state.messages.append({'role': 'user', 'content': user_prompt})
    response = handle_query(user_prompt)
    st.session_state.messages.append({'role': 'assistant', 'content': response})

# Render the chat transcript (the internal system entry is not shown)
for message in st.session_state.messages:
    if message['role'] != 'system':
        st.chat_message(message['role']).write(message['content'])