# Streamlit app: PDF AI Assistant — Document AI extraction + FAISS retrieval + Groq LLM.
import json
import os
import tempfile
import textwrap

import faiss
import numpy as np
import requests
import streamlit as st
from google.api_core.client_options import ClientOptions
from google.cloud import documentai_v1
from sentence_transformers import SentenceTransformer
# ------------------- Secure Credential Loading for Hugging Face ------------------- #
# Materializes the GCP Service Account key (stored as a Space secret) onto disk
# so Application Default Credentials (ADC) can find it via GOOGLE_APPLICATION_CREDENTIALS.
# 1. Load the Service Account JSON string from the environment variable (secret)
gcp_credentials_json_str = os.getenv("GCP_CREDENTIALS_JSON")
project_id = "wise-env-461717-t5"  # Fallback project id; overwritten below from the credentials JSON
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # Bearer token for the Groq chat-completion calls
# 2. Check if the secret is present
if gcp_credentials_json_str:
    try:
        # Write to the /tmp/ directory, which is writable on Hugging Face Spaces
        credentials_file_path = "/tmp/gcp_service_account.json"
        # 3. Write the JSON string to the file in the temporary directory
        with open(credentials_file_path, "w") as f:
            f.write(gcp_credentials_json_str)
        # 4. Point ADC at the freshly written key file
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_file_path
        # Extract project_id from the credentials for convenience
        creds_dict = json.loads(gcp_credentials_json_str)
        project_id = creds_dict.get("project_id")
    except Exception as e:
        # Covers malformed JSON and filesystem failures; app cannot proceed without credentials.
        st.error(f"π¨ Failed to process GCP credentials: {e}")
        st.stop()
else:
    st.error("π¨ GCP_CREDENTIALS_JSON secret not found! Please add it to your Hugging Face Space settings.")
    st.stop()
# ------------------- Configuration ------------------- #
# Project ID is dynamically loaded from the service account JSON above.
if not project_id:
    st.error("π¨ Project ID could not be found in the GCP credentials.")
    st.stop()
# You still need to provide your Processor ID and location
processor_id = "86a7eec52bbb9616"  # <-- REPLACE WITH YOUR PROCESSOR ID
location = "us"  # Document AI processor location, e.g. "us" or "eu"
# ------------------- Google Document AI Client (Uses ADC) ------------------- #
# The client automatically picks up the credentials file set via
# GOOGLE_APPLICATION_CREDENTIALS in the section above.
try:
    # The API endpoint must be regional and match the processor's location.
    opts = ClientOptions(api_endpoint=f"{location}-documentai.googleapis.com")
    docai_client = documentai_v1.DocumentProcessorServiceClient(client_options=opts)
    # Fully-qualified resource name: projects/{project}/locations/{location}/processors/{id}
    full_processor_name = docai_client.processor_path(project_id, location, processor_id)
except Exception as e:
    st.error(f"Error initializing Document AI client: {e}")
    st.stop()
@st.cache_resource
def load_embedding_model():
    """Load the sentence-embedding model once per server process.

    FIX: decorated with ``st.cache_resource`` — without it, Streamlit re-runs
    the whole script on every widget interaction and re-instantiated the model
    each time, which is slow and memory-hungry.

    Returns:
        SentenceTransformer: the ``all-MiniLM-L6-v2`` embedding model.
    """
    # Hugging Face Spaces only allows writes under /tmp, so redirect the HF cache there.
    cache_dir = "/tmp/hf_cache"
    os.makedirs(cache_dir, exist_ok=True)
    # Set Hugging Face cache environment variables (belt-and-braces alongside cache_folder).
    os.environ["TRANSFORMERS_CACHE"] = cache_dir
    os.environ["HF_HOME"] = cache_dir
    return SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)


embed_model = load_embedding_model()
# ------------------- Utility Functions ------------------- # | |
def chunk_text(text, max_chars=500):
    """Split *text* into word-wrapped chunks of at most *max_chars* characters each."""
    chunks = textwrap.wrap(text, max_chars)
    return chunks
def extract_text_with_documentai(file_path):
    """Send a local PDF through the configured Document AI processor and return its full text."""
    with open(file_path, "rb") as pdf_file:
        pdf_bytes = pdf_file.read()
    # Inline (synchronous) processing of the raw PDF bytes.
    raw_doc = documentai_v1.RawDocument(content=pdf_bytes, mime_type="application/pdf")
    req = documentai_v1.ProcessRequest(name=full_processor_name, raw_document=raw_doc)
    outcome = docai_client.process_document(request=req)
    return outcome.document.text
def build_index(text):
    """Chunk *text*, embed each chunk, and build a FAISS L2 index over the embeddings.

    Returns:
        tuple: (faiss.IndexFlatL2, list[str]) — the searchable index and its chunks.
    """
    chunks = chunk_text(text)
    vectors = np.array(embed_model.encode(chunks))
    # Flat (exhaustive) L2 index — exact nearest-neighbour search, no training step.
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index, chunks
def retrieve_context(query, index, text_chunks, top_k=5):
    """Return the *top_k* chunks nearest to *query* in embedding space (L2 distance)."""
    query_vec = np.array(embed_model.encode([query]))
    _, neighbor_ids = index.search(query_vec, top_k)
    # search() returns one row of indices per query; we issued a single query.
    return [text_chunks[idx] for idx in neighbor_ids[0]]
# ------------------- Groq API Functions ------------------- #
def ask_groq_agent(query, context):
    """Answer *query* with Groq's chat API, grounded in the retrieved *context*.

    Args:
        query: the user's question.
        context: newline-joined retrieved chunks used to ground the answer.

    Returns:
        str: the model's answer text.

    Raises:
        requests.HTTPError: if the Groq API returns a non-2xx status.
        requests.Timeout: if the API does not respond within 60 seconds.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
        json={
            "model": "llama3-70b-8192",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
        },
        timeout=60,  # FIX: without a timeout a stalled request hangs the Streamlit session forever
    )
    # FIX: surface API errors explicitly instead of an opaque KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def get_summary(text):
    """Ask the Groq API for a concise summary of *text*.

    Only the first 4000 characters are sent to keep the prompt within limits.

    Args:
        text: the full document text to summarize.

    Returns:
        str: the generated summary.

    Raises:
        requests.HTTPError: if the Groq API returns a non-2xx status.
        requests.Timeout: if the API does not respond within 60 seconds.
    """
    prompt = f"Please provide a concise summary of the following document:\n\n{text[:4000]}"
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
        json={
            "model": "llama3-70b-8192",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
        },
        timeout=60,  # FIX: without a timeout a stalled request hangs the Streamlit session forever
    )
    # FIX: surface API errors explicitly instead of an opaque KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def _parse_flashcards(content):
    """Parse "Q: ..." / "A: ..." line pairs into [{"question", "answer"}, ...] dicts."""
    flashcards = []
    question = None
    for line in content.strip().splitlines():
        line = line.strip()
        if line.lower().startswith("q:"):
            question = line[2:].strip()
        elif line.lower().startswith("a:") and question:
            flashcards.append({"question": question, "answer": line[2:].strip()})
            question = None  # each question is consumed by its first answer
    return flashcards


def generate_flashcards(text_chunks):
    """Generate Q/A flashcards from the document chunks via the Groq API.

    Args:
        text_chunks: list of text chunks; joined and sent as the prompt body.

    Returns:
        list[dict]: flashcards as {"question": str, "answer": str} dicts.

    Raises:
        requests.HTTPError: if the Groq API returns a non-2xx status.
        requests.Timeout: if the API does not respond within 60 seconds.
    """
    joined_text = "\n".join(text_chunks)
    prompt = (
        "Generate 5 helpful flashcards from the following content. "
        "Use the format exactly like this:\n\n"
        "Q: What is ...?\nA: ...\n\nQ: How does ...?\nA: ...\n\n"
        "Text:\n" + joined_text
    )
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
        json={
            "model": "llama3-70b-8192",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.5,
        },
        timeout=60,  # FIX: without a timeout a stalled request hangs the Streamlit session forever
    )
    # FIX: surface API errors explicitly instead of an opaque KeyError on "choices".
    response.raise_for_status()
    content = response.json()["choices"][0]["message"]["content"]
    return _parse_flashcards(content)
st.title("π PDF AI Assistant (Groq + DocAI)")

# Seed per-session state exactly once (Streamlit re-runs this script on each interaction).
if "index" not in st.session_state:
    st.session_state["index"] = None
    st.session_state["text_chunks"] = []
    st.session_state["raw_text"] = ""
with st.sidebar:
    st.header("π€ Upload PDF")
    uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
    if uploaded_file is not None:
        # FIX: initialize before the try so the finally-clause cannot hit a
        # NameError when NamedTemporaryFile creation itself fails.
        tmp_path = None
        try:
            # Persist the upload to disk — Document AI helper reads from a path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file.flush()
                tmp_path = tmp_file.name
            # DEBUG: File info
            st.write("Saved file at:", tmp_path)
            st.write("File size:", os.path.getsize(tmp_path), "bytes")
            st.write("File exists:", os.path.exists(tmp_path))
            with st.spinner("Extracting text using Document AI..."):
                raw_text = extract_text_with_documentai(tmp_path)
                index, text_chunks = build_index(raw_text)
                st.session_state.index = index
                st.session_state.text_chunks = text_chunks
                st.session_state.raw_text = raw_text
            st.success("β Document processed successfully.")
        except Exception as e:
            st.error(f"Error: {e}")
        finally:
            # Only unlink if the temp file was actually created.
            if tmp_path and os.path.exists(tmp_path):
                os.unlink(tmp_path)
# ------------------- Q&A Interface ------------------- #
st.subheader("β Ask Questions")
if not st.session_state.index:
    st.info("Upload a PDF to start asking questions.")
else:
    question = st.text_input("Enter your question")
    if st.button("Ask"):
        # Ground the LLM answer in the nearest chunks from the FAISS index.
        hits = retrieve_context(question, st.session_state.index, st.session_state.text_chunks)
        context = "\n\n".join(hits)
        answer = ask_groq_agent(question, context)
        st.markdown(f"**Answer:** {answer}")
# ------------------- Summary Interface ------------------- #
st.subheader("π Document Summary")
if not st.session_state.text_chunks:
    st.info("Upload a PDF to get a summary.")
else:
    if st.button("Generate Summary"):
        with st.spinner("Generating summary..."):
            # Rejoin the chunks; get_summary truncates the prompt internally.
            summary = get_summary(" ".join(st.session_state.text_chunks))
            st.markdown(summary)
# ------------------- Flashcards ------------------- #
st.subheader("π§ Flashcards")
if not st.session_state.text_chunks:
    st.info("Upload a PDF to generate flashcards.")
else:
    if st.button("Generate Flashcards"):
        with st.spinner("Generating flashcards..."):
            cards = generate_flashcards(st.session_state.text_chunks)
            for card in cards:
                st.markdown(f"**Q: {card['question']}**\n\nA: {card['answer']}")