# RAG / app.py
import streamlit as st
import torch
import fitz # PyMuPDF
import os
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- CONFIG ---
HF_TOKEN = os.environ["HF_TOKEN"]  # Your Hugging Face token
# Check if CUDA (GPU) is available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move device info to sidebar
with st.sidebar:
    st.text(f"🖥️ Using device: {DEVICE}")
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
"microsoft/phi-2",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(DEVICE)
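
# Phi-2 is a ~2.7B-parameter causal LM; loading it in float16 roughly halves
# its GPU memory footprint, while CPU inference falls back to full-precision float32.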
# Load sentence transformer and move to device
embedder = SentenceTransformer("all-MiniLM-L6-v2")
embedder = embedder.to(DEVICE)
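# The embedder above (all-MiniLM-L6-v2) produces 384-dimensional sentence
# embeddings; the FAISS index below is built over these vectors.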
st.title("🔍 RAG App using 🤖 Phi-2")
# Move file upload and processing to sidebar
with st.sidebar:
    st.header("📁 Document Upload")
    uploaded_file = st.file_uploader("📄 Upload a PDF or TXT file", type=["pdf", "txt"])
# Extract text from file
def extract_text(file):
text = ""
if file.type == "application/pdf":
doc = fitz.open(stream=file.read(), filetype="pdf")
for page in doc:
text += page.get_text()
elif file.type == "text/plain":
text = file.read().decode("utf-8")
return text
# Split into chunks
def split_into_chunks(text, chunk_size=500):
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
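
# NOTE: chunks are fixed 500-character windows with no overlap, so a sentence
# can be split across a boundary; overlapping windows or sentence-aware
# splitting are common variants if retrieval misses answers near chunk edges.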
# Create FAISS index
def create_faiss_index(chunks):
    embeddings = embedder.encode(chunks, convert_to_tensor=True, device=DEVICE)
    embeddings_np = embeddings.cpu().detach().numpy()
    dim = embeddings_np.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings_np)
    return index, embeddings_np
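
# IndexFlatL2 does an exact (brute-force) L2 search over all chunk embeddings,
# which is fine at single-document scale; approximate indexes such as IVF or
# HNSW only pay off for much larger corpora.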
# Retrieve top-k chunks
def retrieve_chunks(query, chunks, index, k=5):
    query_embedding = embedder.encode([query], convert_to_tensor=True, device=DEVICE)
    query_embedding_np = query_embedding.cpu().detach().numpy()
    D, I = index.search(query_embedding_np, k)  # D: L2 distances, I: indices of the k nearest chunks
    return [chunks[i] for i in I[0]]
# --- MAIN LOGIC ---
if uploaded_file:
    with st.sidebar:
        st.success("✅ File uploaded successfully!")
        raw_text = extract_text(uploaded_file)
        chunks = split_into_chunks(raw_text)
        st.info(f"📚 Document split into {len(chunks)} chunks")

        # Add expander for debugging text in sidebar
        with st.expander("📄 View Document Text"):
            st.text_area("Extracted text", raw_text, height=200)

    index, embeddings = create_faiss_index(chunks)
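
    # NOTE: the index is rebuilt on every Streamlit rerun; for large documents,
    # stashing it in st.session_state would avoid recomputing the embeddings.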
    # Main chat interface
    st.markdown("### 💬 Chat with your Document")

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Chat input
    if prompt := st.chat_input("Ask something about the document"):
        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                context = "\n".join(retrieve_chunks(prompt, chunks, index))

                # Updated prompt for Phi-2's instruction style
                full_prompt = (
                    f"Instruction: Answer the following question using only the context provided. "
                    f"Extract specific information directly from the context when available. "
                    f"If the answer is not in the context, respond with 'Information not found.'\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {prompt}\n\n"
                    f"Answer: "
                )
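
                # Phi-2 is a base model rather than a chat-tuned one; this plain
                # Instruction/Context/Question/Answer layout is one reasonable zero-shot
                # format, and the "Answer:" marker is what the extraction step below keys on.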

                input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(DEVICE)
                with torch.no_grad():
                    outputs = model.generate(
                        input_ids,
                        max_new_tokens=256,
                        num_return_sequences=1,
                        temperature=0.2,
                        do_sample=True,
                        top_p=0.9,
                    )
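
                    # Sampling settings above (temperature=0.2, top_p=0.9) keep generation
                    # close to greedy decoding; max_new_tokens=256 caps the answer length.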

                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

                # Extract the answer part
                if "Answer:" in generated_text:
                    answer = generated_text.split("Answer:")[-1].strip()
                else:
                    answer = generated_text.replace(full_prompt, "").strip()

                st.markdown(answer)

        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": answer})
else:
    # Display message when no file is uploaded
    st.info("👈 Please upload a document using the sidebar to start chatting!")