# app.py — uploaded by itsanurag in commit 9890aee ("Create app.py"), 2.67 kB.
import fitz # PyMuPDF
from transformers import DPRQuestionEncoderTokenizer, DPRQuestionEncoder
from transformers import T5Tokenizer, T5ForConditionalGeneration
import json
import faiss
import numpy as np
import streamlit as st
def extract_text_from_pdf(pdf_path):
    """Return the plain text of every page of ``pdf_path``, concatenated in page order.

    Args:
        pdf_path: path of the PDF file to read.

    Returns:
        The text of all pages joined with no separator (each page's text is
        taken from PyMuPDF's ``page.get_text("text")``).
    """
    # The ``with`` block closes the document (and its file handle) even if
    # text extraction raises part-way through.
    with fitz.open(pdf_path) as document:
        return "".join(page.get_text("text") for page in document)
def chunk_text(text, chunk_size=1000, overlap=0):
    """Split ``text`` into fixed-width character chunks.

    Args:
        text: the string to split.
        chunk_size: maximum number of characters per chunk (must be >= 1).
        overlap: number of characters each chunk shares with the previous one
            (0 <= overlap < chunk_size). The default 0 yields back-to-back,
            non-overlapping chunks.

    Returns:
        A list of substrings covering all of ``text`` in order; ``[]`` for an
        empty string.

    Raises:
        ValueError: if ``chunk_size`` < 1 or ``overlap`` is out of range.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )
    if not text:
        return []
    step = chunk_size - overlap
    # Don't start a window that would fall entirely inside the previous chunk,
    # but always emit at least one chunk for non-empty text.
    stop = max(len(text) - overlap, 1)
    return [text[i:i + chunk_size] for i in range(0, stop, step)]
# Initialize models.
# Streamlit re-executes this whole script on every widget interaction, so the
# loaders are wrapped in st.cache_resource: each model is built once per server
# process instead of being reloaded from disk on every rerun.
# NOTE(review): DPR also publishes a separate passage encoder
# (facebook/dpr-ctx_encoder-single-nq-base, DPRContextEncoder); this app embeds
# document chunks with the question encoder as well — consider switching
# index_chunks to the context encoder for better retrieval.
RETRIEVER_MODEL_NAME = 'facebook/dpr-question_encoder-single-nq-base'
GENERATOR_MODEL_NAME = 't5-base'


@st.cache_resource
def _load_models(retriever_name, generator_name):
    """Load the DPR retriever and T5 generator with their tokenizers.

    Returns:
        (retriever_tokenizer, retriever, generator_tokenizer, generator)
    """
    return (
        DPRQuestionEncoderTokenizer.from_pretrained(retriever_name),
        DPRQuestionEncoder.from_pretrained(retriever_name),
        T5Tokenizer.from_pretrained(generator_name),
        T5ForConditionalGeneration.from_pretrained(generator_name),
    )


retriever_tokenizer, retriever, generator_tokenizer, generator = _load_models(
    RETRIEVER_MODEL_NAME, GENERATOR_MODEL_NAME
)
def index_chunks(chunks, batch_size=16):
    """Embed ``chunks`` with the DPR retriever and add them to a FAISS index.

    Args:
        chunks: non-empty list of text segments to index.
        batch_size: number of chunks encoded per forward pass (>= 1).

    Returns:
        (index, chunk_embeddings): a FAISS flat inner-product index over the
        chunk embeddings, and the float32 embedding matrix whose row ``i``
        corresponds to ``chunks[i]`` (so FAISS result ids index ``chunks``).

    Raises:
        ValueError: if ``chunks`` is empty or ``batch_size`` < 1.
    """
    import torch  # local import: only needed for no-grad inference

    if not chunks:
        raise ValueError("index_chunks() needs at least one chunk to index")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batches = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for start in range(0, len(chunks), batch_size):
            batch = list(chunks[start:start + batch_size])
            inputs = retriever_tokenizer(
                batch, return_tensors='pt', padding=True, truncation=True
            )
            batches.append(retriever(**inputs).pooler_output.cpu().numpy())

    # FAISS requires a C-contiguous float32 matrix.
    chunk_embeddings = np.ascontiguousarray(np.vstack(batches), dtype=np.float32)
    # DPR embeddings are scored by dot product, so use an inner-product index;
    # the dimension comes from the model output instead of a hard-coded 768.
    index = faiss.IndexFlatIP(chunk_embeddings.shape[1])
    index.add(chunk_embeddings)
    return index, chunk_embeddings
def get_answer(query, chunks, index, chunk_embeddings, max_length=50, top_k=1):
    """Answer ``query`` by retrieving relevant chunks and generating with T5.

    Args:
        query: the user's question.
        chunks: the text segments whose embeddings were added to ``index``,
            in the same order.
        index: FAISS index built over the chunk embeddings.
        chunk_embeddings: unused; kept for backward compatibility with callers.
        max_length: maximum number of tokens the generator may produce.
        top_k: number of best-matching chunks to pass to the generator as
            context (default 1, the single best chunk).

    Returns:
        The decoded answer string, or ``""`` if nothing could be retrieved.

    Raises:
        ValueError: if ``top_k`` < 1.
    """
    import torch  # local import: only needed for no-grad inference

    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if not chunks:
        return ""

    # Encode the query with the retriever (inference only).
    with torch.no_grad():
        inputs = retriever_tokenizer(query, return_tensors='pt', truncation=True)
        question_embedding = retriever(**inputs).pooler_output.cpu().numpy()
    question_embedding = np.ascontiguousarray(question_embedding, dtype=np.float32)

    # Search for the most relevant chunk(s).
    _, indices = index.search(question_embedding, min(top_k, len(chunks)))
    # FAISS pads missing results with -1; never let that select chunks[-1].
    retrieved = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    if not retrieved:
        return ""

    # T5 must see the question as well as the context; this is the
    # "question: ... context: ..." format T5 checkpoints were trained on.
    prompt = f"question: {query} context: {' '.join(retrieved)}"
    input_ids = generator_tokenizer(
        prompt, return_tensors='pt', truncation=True, max_length=512
    ).input_ids
    output_ids = generator.generate(input_ids, max_length=max_length)
    return generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Load and process PDF.
PDF_PATH = 'policy-booklet-0923.pdf'


@st.cache_resource
def _load_knowledge_base(pdf_path):
    """Extract, chunk and index ``pdf_path`` once per server process.

    Streamlit reruns the whole script on every interaction; caching keeps the
    PDF parsing and chunk embedding from being repeated for every question.
    The cache is keyed on the path string only, so edits to the file on disk
    are not picked up until the cache is cleared or the server restarts.

    Returns:
        (pdf_text, chunks, index, chunk_embeddings)
    """
    text = extract_text_from_pdf(pdf_path)
    text_chunks = chunk_text(text)
    chunk_index, embeddings = index_chunks(text_chunks)
    return text, text_chunks, chunk_index, embeddings


pdf_text, chunks, index, chunk_embeddings = _load_knowledge_base(PDF_PATH)
# Streamlit front-end
st.title("RAG-Powered PDF Chatbot")
user_query = st.text_input("Enter your question:")
# Whitespace-only input should not trigger a full retrieval + generation pass.
query = user_query.strip() if user_query else ""
if query:
    with st.spinner("Generating answer..."):
        answer = get_answer(query, chunks, index, chunk_embeddings, max_length=100)
    st.write("Answer:", answer)