import os

import faiss
import fitz  # PyMuPDF
import numpy as np
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
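
# Hugging Face access token; set HF_TOKEN in your environment before launching the app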
HF_TOKEN = os.environ["HF_TOKEN"]
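
# Run on GPU when available, otherwise fall back to CPU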
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with st.sidebar:
    st.text(f"🖥️ Using device: {DEVICE}")
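
# Load Phi-2 (2.7B parameters) and its tokenizer; fp16 on GPU halves memory use,
# fp32 keeps CPU inference numerically safe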
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=HF_TOKEN,
).to(DEVICE)
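
# Sentence embedder; all-MiniLM-L6-v2 maps text to 384-dimensional vectors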
embedder = SentenceTransformer("all-MiniLM-L6-v2")
embedder = embedder.to(DEVICE)

st.title("📚 RAG App using 🤗 Phi-2")

with st.sidebar:
    st.header("📄 Document Upload")
    uploaded_file = st.file_uploader("📁 Upload a PDF or TXT file", type=["pdf", "txt"])
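

# Extract raw text from the upload: PyMuPDF for PDFs, a plain UTF-8 decode for TXT files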
def extract_text(file):
    text = ""
    if file.type == "application/pdf":
        doc = fitz.open(stream=file.read(), filetype="pdf")
        for page in doc:
            text += page.get_text()
    elif file.type == "text/plain":
        text = file.read().decode("utf-8")
    return text
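

# Naive fixed-width chunking: 500-character windows with no overlap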
def split_into_chunks(text, chunk_size=500):
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
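

# Embed every chunk and add the vectors to an exact (brute-force) L2 FAISS index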
def create_faiss_index(chunks):
    embeddings = embedder.encode(chunks, convert_to_tensor=True, device=DEVICE)
    embeddings_np = embeddings.detach().cpu().numpy()
    dim = embeddings_np.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings_np)
    return index, embeddings_np
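

# Embed the query and return the k nearest chunks by L2 distance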
def retrieve_chunks(query, chunks, index, k=5):
    query_embedding = embedder.encode([query], convert_to_tensor=True, device=DEVICE)
    query_embedding_np = query_embedding.detach().cpu().numpy()
    distances, indices = index.search(query_embedding_np, k)
    return [chunks[i] for i in indices[0]]
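

# Main flow: once a document is uploaded, index it and open the chat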
if uploaded_file:
    with st.sidebar:
        st.success("✅ File uploaded successfully!")
        raw_text = extract_text(uploaded_file)
        chunks = split_into_chunks(raw_text)
        st.info(f"📄 Document split into {len(chunks)} chunks")

    with st.expander("📄 View Document Text"):
        st.text_area("Extracted text", raw_text, height=200)

    index, embeddings = create_faiss_index(chunks)

    st.markdown("### 💬 Chat with your Document")

    if "messages" not in st.session_state:
        st.session_state.messages = []
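
    # Replay the conversation so far on every Streamlit rerun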
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
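
    # Handle a new question from the user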
    if prompt := st.chat_input("Ask something about the document"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
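
        # Retrieve the most relevant chunks and ground the prompt in them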
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                context = "\n".join(retrieve_chunks(prompt, chunks, index))

                full_prompt = (
                    "Instruction: Answer the following question using only the context provided. "
                    "Extract specific information directly from the context when available. "
                    "If the answer is not in the context, respond with 'Information not found.'\n\n"
                    f"Context:\n{context}\n\n"
                    f"Question: {prompt}\n\n"
                    "Answer: "
                )
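
                # Low temperature and nucleus sampling keep answers close to the retrieved context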
                input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(DEVICE)

                with torch.no_grad():
                    outputs = model.generate(
                        input_ids,
                        max_new_tokens=256,
                        num_return_sequences=1,
                        temperature=0.2,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id,  # Phi-2 defines no pad token
                    )
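
                # The decoded output echoes the prompt, so keep only the text after the final "Answer:" marker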
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                if "Answer:" in generated_text:
                    answer = generated_text.split("Answer:")[-1].strip()
                else:
                    answer = generated_text.replace(full_prompt, "").strip()

            st.markdown(answer)

        st.session_state.messages.append({"role": "assistant", "content": answer})

else:
    st.info("📄 Please upload a document using the sidebar to start chatting!")