import streamlit as st
import os
import time
import base64
import hashlib
from io import BytesIO
from PIL import Image
import PyPDF2
from pdf2image import convert_from_path
from anthropic import Anthropic
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download, list_repo_files, upload_file
from pathlib import Path
import shutil
import json
import re
# ============================================================================
# PRODUCTION MATH AI SYSTEM v5.0 - FINAL
# ============================================================================
st.set_page_config(
page_title="Math AI System",
page_icon="🎓",
layout="wide"
)
COLLECTION_NAME = "math_knowledge_base"
DATASET_REPO = "Hebaelsayed/math-ai-documents"  # change this if you use your own HF dataset repo
CACHE_DIR = Path("/tmp/hf_cache")
CACHE_DIR.mkdir(exist_ok=True)
OCR_CACHE_FOLDER = "ocr_cache/"
OCR_SCANNED_FOLDER = "OCR_SCANNED/"  # folder for full OCR transcriptions saved as .txt
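# Expected layout of the HF dataset repo (as used by the functions below):
#   books/        reference PDFs used as OCR context
#   exams/        exam PDFs used as OCR context
#   answers/      handwritten answer PDFs awaiting OCR
#   ocr_cache/    per-file OCR metadata (JSON: text, tokens, cost)
#   OCR_SCANNED/  full OCR transcriptions (.txt)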
# ============================================================================
# EMBEDDING MODELS (ALL FREE)
# ============================================================================
EMBEDDING_MODELS = {
"MiniLM-L6 (Fast, 384D)": {
"name": "sentence-transformers/all-MiniLM-L6-v2",
"dimensions": 384,
"speed": "Very Fast",
"quality": "Good",
"cost": "FREE"
},
"MiniLM-L12 (Balanced, 384D)": {
"name": "sentence-transformers/all-MiniLM-L12-v2",
"dimensions": 384,
"speed": "Fast",
"quality": "Better",
"cost": "FREE"
},
"MPNet (Best, 768D)": {
"name": "sentence-transformers/all-mpnet-base-v2",
"dimensions": 768,
"speed": "Medium",
"quality": "Excellent",
"cost": "FREE"
},
"BGE-Large (SOTA, 1024D)": {
"name": "BAAI/bge-large-en-v1.5",
"dimensions": 1024,
"speed": "Slower",
"quality": "State-of-the-art",
"cost": "FREE"
}
}
# ============================================================================
# CHUNKING STRATEGIES (ALL FREE)
# ============================================================================
CHUNKING_STRATEGIES = {
"Fixed-Size": {
"description": "Simple word-based chunking with overlap",
"pros": "Fast, predictable",
"cons": "May split equations",
"cost": "FREE"
},
"Semantic": {
"description": "Split at paragraph boundaries",
"pros": "Preserves meaning",
"cons": "Variable sizes",
"cost": "FREE"
},
"Exercise-Aware": {
"description": "Split by Exercise 1, 2, etc.",
"pros": "Perfect for exams",
"cons": "Structured docs only",
"cost": "FREE"
},
"LaTeX-Aware": {
"description": "Preserve math expressions",
"pros": "Never splits equations",
"cons": "More complex",
"cost": "FREE"
}
}
# ============================================================================
# PUBLIC DATASETS (FIXED & WORKING)
# ============================================================================
PUBLIC_DATASETS = {
"Linear Algebra": {
"description": "Matrices, Eigenvalues, SVD",
"datasets": [
{
"name": "Hendrycks MATH (Precalculus)",
"hf_path": "EleutherAI/hendrycks_math",
"subset": "precalculus",
"problem_field": "problem",
"solution_field": "solution",
"relevance": "⭐⭐⭐⭐⭐",
"topics": "Matrices, eigenvalues"
},
{
"name": "DART-Math-Hard",
"hf_path": "hkust-nlp/dart-math-hard",
"subset": None,
"problem_field": "query",
"solution_field": "response",
"relevance": "⭐⭐⭐⭐",
"topics": "Difficult reasoning"
},
{
"name": "MATH-Hard",
"hf_path": "lighteval/MATH-Hard",
"subset": None,
"problem_field": "problem",
"solution_field": "solution",
"relevance": "⭐⭐⭐⭐",
"topics": "Competition math"
}
]
},
"Optimization & Calculus": {
"description": "Gradients, Convex Optimization",
"datasets": [
{
"name": "Hendrycks MATH (Algebra)",
"hf_path": "EleutherAI/hendrycks_math",
"subset": "intermediate_algebra",
"problem_field": "problem",
"solution_field": "solution",
"relevance": "⭐⭐⭐⭐⭐",
"topics": "Calculus, optimization"
}
]
}
}
# ============================================================================
# SESSION STATE
# ============================================================================
if 'embedding_model' not in st.session_state:
st.session_state.embedding_model = EMBEDDING_MODELS["MiniLM-L6 (Fast, 384D)"]["name"]
if 'chunking_strategy' not in st.session_state:
st.session_state.chunking_strategy = "Fixed-Size"
if 'chunk_size' not in st.session_state:
st.session_state.chunk_size = 150
if 'chunk_overlap' not in st.session_state:
st.session_state.chunk_overlap = 30
# ============================================================================
# CACHED RESOURCES
# ============================================================================
@st.cache_resource
def get_qdrant_client():
return QdrantClient(
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY")
)
@st.cache_resource
def get_claude_client():
return Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
@st.cache_resource
def get_embedding_model(model_name):
"""Lazy load embedding model only when needed"""
with st.spinner(f"Loading embedding model: {model_name}..."):
return SentenceTransformer(model_name)
# ============================================================================
# CACHE FUNCTIONS
# ============================================================================
def get_cache_path(file_path):
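"""Map a repo file path to a deterministic local cache path (MD5 hash of the path)."""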
file_hash = hashlib.md5(file_path.encode()).hexdigest()
return CACHE_DIR / f"{file_hash}.pdf"
def is_file_cached(file_path):
return get_cache_path(file_path).exists()
def download_with_cache(file_path):
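"""Return (local_path, was_cached): use the /tmp cache if present, otherwise download from the HF dataset repo."""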
cache_path = get_cache_path(file_path)
if cache_path.exists():
return str(cache_path), True
try:
hf_token = os.getenv("HF_TOKEN")
downloaded_path = hf_hub_download(
repo_id=DATASET_REPO,
filename=file_path,
repo_type="dataset",
token=hf_token
)
shutil.copy(downloaded_path, cache_path)
return str(cache_path), False
except Exception as e:
st.error(f"Download error: {e}")
return None, False
def clear_pdf_cache():
if CACHE_DIR.exists():
shutil.rmtree(CACHE_DIR)
CACHE_DIR.mkdir(exist_ok=True)
return True
def get_ocr_filename(pdf_name):
return f"{OCR_CACHE_FOLDER}{pdf_name}.json"
def is_ocr_cached_hf(pdf_name):
try:
hf_token = os.getenv("HF_TOKEN")
all_files = list_repo_files(
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
return get_ocr_filename(pdf_name) in all_files
except:
return False
def save_ocr_to_hf(pdf_name, transcribed_text, total_tokens):
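"""Upload the OCR text plus token/cost metadata as a JSON file under ocr_cache/ in the HF dataset repo."""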
try:
hf_token = os.getenv("HF_TOKEN")
cache_data = {
"file_name": pdf_name,
"transcribed_text": transcribed_text,
"total_tokens": total_tokens,
"timestamp": time.time(),
"cost": total_tokens * 0.000003
}
temp_path = Path("/tmp") / f"{pdf_name}_ocr.json"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(cache_data, f, ensure_ascii=False, indent=2)
upload_file(
path_or_fileobj=str(temp_path),
path_in_repo=get_ocr_filename(pdf_name),
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
temp_path.unlink()
return True
except Exception as e:
st.warning(f"OCR save failed: {e}")
return False
def load_ocr_from_hf(pdf_name):
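"""Return (transcribed_text, total_tokens) from the cached OCR JSON, or (None, 0) if it is missing."""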
try:
hf_token = os.getenv("HF_TOKEN")
ocr_path = hf_hub_download(
repo_id=DATASET_REPO,
filename=get_ocr_filename(pdf_name),
repo_type="dataset",
token=hf_token
)
with open(ocr_path, 'r', encoding='utf-8') as f:
cache_data = json.load(f)
return cache_data.get('transcribed_text'), cache_data.get('total_tokens', 0)
except:
return None, 0
def get_ocr_cache_stats_hf():
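"""Return (number of cached OCR files, total recorded OCR cost in USD)."""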
try:
hf_token = os.getenv("HF_TOKEN")
all_files = list_repo_files(
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
ocr_files = [f for f in all_files if f.startswith(OCR_CACHE_FOLDER)]
total_cost = 0.0
for ocr_file in ocr_files:
try:
ocr_path = hf_hub_download(
repo_id=DATASET_REPO,
filename=ocr_file,
repo_type="dataset",
token=hf_token
)
with open(ocr_path, 'r') as f:
data = json.load(f)
total_cost += data.get('cost', 0)
except:
pass
return len(ocr_files), total_cost
except:
return 0, 0.0
# ============================================================================
# OCR_SCANNED FOLDER FUNCTIONS
# ============================================================================
def list_ocr_scanned_files():
"""List all files in OCR_SCANNED folder"""
try:
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
return []
all_files = list_repo_files(
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
# Get all files in OCR_SCANNED, not just .txt (for compatibility)
ocr_files = [f for f in all_files if f.startswith(OCR_SCANNED_FOLDER) and not f.endswith('/')]
return ocr_files
except Exception as e:
# Don't show error at startup, only when actively used
return []
def get_ocr_scanned_filename(pdf_name):
"""Generate OCR_SCANNED filename from PDF name"""
base_name = pdf_name.replace('.pdf', '')
return f"{OCR_SCANNED_FOLDER}{base_name}.txt"
def is_pdf_ocr_scanned(pdf_name):
"""Check if PDF has corresponding OCR_SCANNED file"""
try:
hf_token = os.getenv("HF_TOKEN")
all_files = list_repo_files(
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
return get_ocr_scanned_filename(pdf_name) in all_files
except Exception as e:
st.error(f"Error checking OCR status for {pdf_name}: {e}")
return False
def save_ocr_scanned_to_hf(pdf_name, transcribed_text):
"""Save OCR transcription to OCR_SCANNED folder as .txt"""
try:
hf_token = os.getenv("HF_TOKEN")
# Create temporary txt file
temp_path = Path("/tmp") / f"{pdf_name}_ocr.txt"
with open(temp_path, 'w', encoding='utf-8') as f:
f.write(transcribed_text)
# Upload to HuggingFace
upload_file(
path_or_fileobj=str(temp_path),
path_in_repo=get_ocr_scanned_filename(pdf_name),
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
temp_path.unlink()
return True
except Exception as e:
st.warning(f"OCR save to OCR_SCANNED failed: {e}")
return False
def load_ocr_scanned_from_hf(filename):
"""Load OCR transcription from OCR_SCANNED folder"""
try:
hf_token = os.getenv("HF_TOKEN")
# If filename doesn't include folder path, add it
if not filename.startswith(OCR_SCANNED_FOLDER):
if filename.endswith('.pdf'):
filename = get_ocr_scanned_filename(filename)
else:
filename = f"{OCR_SCANNED_FOLDER}{filename}"
ocr_path = hf_hub_download(
repo_id=DATASET_REPO,
filename=filename,
repo_type="dataset",
token=hf_token
)
with open(ocr_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
st.error(f"Error loading {filename}: {e}")
return None
def get_unscanned_pdfs():
"""Get list of PDFs in answers/ that don't have OCR_SCANNED files"""
try:
answer_files = list_dataset_files("answers/")
answer_names = [f.split('/')[-1] for f in answer_files]
unscanned = []
for name in answer_names:
if not is_pdf_ocr_scanned(name):
unscanned.append(name)
return unscanned
except Exception as e:
st.error(f"Error getting unscanned PDFs: {e}")
return []
def get_context_for_ocr(pdf_name):
"""Load context from books/ and exams/ folders for better OCR"""
context_text = ""
try:
# Try to load from books folder
book_files = list_dataset_files("books/")
if book_files:
# Load first book as context (or implement smarter selection)
book_path, _ = download_with_cache(book_files[0])
if book_path:
book_text = extract_text_from_pdf(book_path)
if book_text:
context_text += f"BOOK CONTEXT:\n{book_text[:3000]}\n\n"
# Try to find matching exam
exam_name = pdf_name.replace('answers', 'exam').replace('Answers', 'Exam')
exam_files = list_dataset_files("exams/")
for exam_file in exam_files:
if exam_name in exam_file:
exam_path, _ = download_with_cache(exam_file)
if exam_path:
exam_text = extract_text_from_pdf(exam_path)
if exam_text:
context_text += f"EXAM CONTEXT:\n{exam_text[:3000]}\n\n"
break
except:
pass
return context_text
# ============================================================================
# CHUNKING FUNCTIONS
# ============================================================================
def chunk_text_fixed(text, chunk_size=150, overlap=30):
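"""Split text into overlapping chunks of roughly chunk_size words."""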
words = text.split()
chunks = []
step = max(1, chunk_size - overlap)  # avoid a zero or negative step if overlap >= chunk_size
for i in range(0, len(words), step):
chunk = ' '.join(words[i:i + chunk_size])
if chunk.strip():
chunks.append(chunk)
return chunks
def chunk_text_semantic(text, max_chunk_size=500):
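"""Group whole paragraphs into chunks of at most max_chunk_size words."""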
paragraphs = text.split('\n\n')
chunks = []
current_chunk = []
current_size = 0
for para in paragraphs:
para_words = para.split()
para_size = len(para_words)
if current_size + para_size <= max_chunk_size:
current_chunk.extend(para_words)
current_size += para_size
else:
if current_chunk:
chunks.append(' '.join(current_chunk))
current_chunk = para_words
current_size = para_size
if current_chunk:
chunks.append(' '.join(current_chunk))
return chunks
def chunk_text_exercise_aware(text, max_chunk_size=500):
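"""Split on 'Exercise N' headings; overly long exercises fall back to fixed-size word splits."""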
exercise_pattern = r'Exercise\s+\d+'
exercises = re.split(exercise_pattern, text)
chunks = []
for exercise in exercises:
if not exercise.strip():
continue
words = exercise.split()
if len(words) <= max_chunk_size:
chunks.append(exercise.strip())
else:
for i in range(0, len(words), max_chunk_size):
chunk = ' '.join(words[i:i + max_chunk_size])
if chunk.strip():
chunks.append(chunk)
return chunks
def chunk_text(text, chunk_size=150, overlap=30, strategy="Fixed-Size"):
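"""Dispatch to the selected chunking strategy (falls back to fixed-size)."""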
if strategy == "Fixed-Size":
return chunk_text_fixed(text, chunk_size, overlap)
elif strategy == "Semantic":
return chunk_text_semantic(text, chunk_size)
elif strategy == "Exercise-Aware":
return chunk_text_exercise_aware(text, chunk_size)
else:
return chunk_text_fixed(text, chunk_size, overlap)
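# Example (illustrative): chunk_text(doc_text, chunk_size=150, overlap=30, strategy="Exercise-Aware")
# always returns a list[str]; note that "Semantic" and "Exercise-Aware" reuse chunk_size as their max chunk size.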
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def check_if_processed(qdrant, file_name, chunk_size=None, embedding_model=None, strategy="filename_only"):
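"""Check whether file_name already has vectors in Qdrant.
strategy controls how strict the match is: 'filename_only' matches the source name,
'filename_settings' also requires the same chunk_size, and 'filename_full'
additionally requires the same embedding_model. Returns (is_processed, chunk_count).
"""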
try:
collection_info = qdrant.get_collection(collection_name=COLLECTION_NAME)
if collection_info.points_count == 0:
return False, 0
except:
return False, 0
try:
filter_conditions = [
FieldCondition(key="source_name", match=MatchValue(value=file_name))
]
if strategy in ["filename_settings", "filename_full"]:
if chunk_size is not None:
filter_conditions.append(
FieldCondition(key="chunk_size", match=MatchValue(value=chunk_size))
)
if strategy == "filename_full":
if embedding_model is not None:
filter_conditions.append(
FieldCondition(key="embedding_model", match=MatchValue(value=embedding_model))
)
count_result = qdrant.count(
collection_name=COLLECTION_NAME,
count_filter=Filter(must=filter_conditions)
)
return count_result.count > 0, count_result.count
except:
return False, 0
def list_dataset_files(folder_path):
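"""List all PDF files under folder_path in the HF dataset repo."""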
try:
hf_token = os.getenv("HF_TOKEN")
all_files = list_repo_files(
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
return [f for f in all_files if f.startswith(folder_path) and f.endswith('.pdf')]
except Exception as e:
st.error(f"Error: {e}")
return []
def extract_text_from_pdf(pdf_path):
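"""Extract text from every page with PyPDF2, prefixing each page with an '=== Page N ===' marker."""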
try:
with open(pdf_path, 'rb') as file:
reader = PyPDF2.PdfReader(file)
text = ""
for page_num, page in enumerate(reader.pages):
text += f"\n\n=== Page {page_num + 1} ===\n\n{page.extract_text()}"
return text
except Exception as e:
st.error(f"Error: {e}")
return None
def pdf_to_images(pdf_path):
try:
images = convert_from_path(pdf_path, dpi=200)
return images
except Exception as e:
st.error(f"Error: {e}")
st.info("πŸ’‘ Add 'poppler-utils' to packages.txt")
return []
def resize_image(image, max_size=(2048, 2048)):
image.thumbnail(max_size, Image.Resampling.LANCZOS)
return image
def image_to_base64(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
def ocr_with_claude(claude_client, image, context=""):
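"""Transcribe one page image with Claude vision; returns (transcription, tokens_used) or (None, 0) on error."""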
resized = resize_image(image.copy())
img_b64 = image_to_base64(resized)
prompt = f"""Transcribe handwritten math exam.
CONTEXT: {context[:2000] if context else ""}
Use LaTeX notation. Preserve structure.
OUTPUT: Transcription only."""
try:
message = claude_client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=4000,
messages=[{
"role": "user",
"content": [
{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}},
{"type": "text", "text": prompt}
]
}]
)
return message.content[0].text, message.usage.input_tokens + message.usage.output_tokens
except Exception as e:
st.error(f"OCR error: {e}")
return None, 0
def get_vector_count(qdrant):
try:
collection_info = qdrant.get_collection(collection_name=COLLECTION_NAME)
return collection_info.points_count
except:
return 0
def delete_vector_database(qdrant):
try:
qdrant.delete_collection(collection_name=COLLECTION_NAME)
return True
except Exception as e:
st.error(f"Error: {e}")
return False
def preview_public_dataset(dataset_info, sample_count):
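"""Stream up to sample_count problem/solution pairs from a public HF dataset for preview."""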
try:
from datasets import load_dataset
hf_path = dataset_info['hf_path']
subset = dataset_info.get('subset')
split = dataset_info.get('split', 'train')
problem_field = dataset_info.get('problem_field')
solution_field = dataset_info.get('solution_field')
if subset:
dataset = load_dataset(hf_path, subset, split=split, streaming=True, trust_remote_code=True)
else:
dataset = load_dataset(hf_path, split=split, streaming=True, trust_remote_code=True)
samples = []
for i, item in enumerate(dataset):
if i >= sample_count:
break
problem = str(item.get(problem_field, 'N/A')) if problem_field else 'N/A'
solution = str(item.get(solution_field, 'N/A')) if solution_field else 'N/A'
samples.append({
'problem': problem[:1000],
'solution': solution[:1000]
})
return samples
except Exception as e:
st.error(f"Error: {e}")
return []
# ============================================================================
# OCR PROCESSING FUNCTIONS
# ============================================================================
def process_single_pdf_ocr(pdf_name):
"""Process a single PDF through OCR pipeline"""
if not claude:
st.error("❌ Claude client not initialized. Check your ANTHROPIC_API_KEY secret.")
return
try:
with st.spinner(f"Processing {pdf_name}..."):
# Download PDF
pdf_path, _ = download_with_cache(f"answers/{pdf_name}")
if not pdf_path:
st.error("Failed to download PDF")
return
# Get context from books and exams
with st.spinner("Loading context from books and exams..."):
context = get_context_for_ocr(pdf_name)
# Convert PDF to images
with st.spinner("Converting PDF to images..."):
images = pdf_to_images(pdf_path)
if not images:
st.error("Failed to convert PDF to images")
return
# OCR each page
all_text = []
total_tokens = 0
progress_bar = st.progress(0)
for idx, image in enumerate(images):
with st.spinner(f"OCR processing page {idx + 1}/{len(images)}..."):
page_text, tokens = ocr_with_claude(claude, image, context)
if page_text:
all_text.append(f"\n\n=== Page {idx + 1} ===\n\n{page_text}")
total_tokens += tokens
progress_bar.progress((idx + 1) / len(images))
progress_bar.empty()
# Combine all text
full_text = "\n".join(all_text)
# Save to OCR_SCANNED folder
with st.spinner("Saving to OCR_SCANNED folder..."):
success = save_ocr_scanned_to_hf(pdf_name, full_text)
# Also save metadata to cache
with st.spinner("Saving metadata..."):
save_ocr_to_hf(pdf_name, full_text, total_tokens)
if success:
cost = total_tokens * 0.000003
st.success(f"βœ… Successfully processed {pdf_name}")
st.info(f"πŸ“Š Tokens: {total_tokens:,} | πŸ’° Cost: ${cost:.4f}")
st.caption(f"πŸ’Ύ Saved to: {get_ocr_scanned_filename(pdf_name)}")
# Remove from unscanned list
if 'unscanned_pdfs' in st.session_state:
if pdf_name in st.session_state.unscanned_pdfs:
st.session_state.unscanned_pdfs.remove(pdf_name)
else:
st.error("Failed to save OCR result")
except Exception as e:
st.error(f"Error processing {pdf_name}: {e}")
def process_batch_ocr(pdf_list):
"""Process multiple PDFs in batch"""
if not claude:
st.error("❌ Claude client not initialized. Check your ANTHROPIC_API_KEY secret.")
return
total_cost = 0.0
total_tokens = 0
success_count = 0
for pdf_name in pdf_list:
st.markdown(f"### Processing: {pdf_name}")
try:
# Download PDF
pdf_path, _ = download_with_cache(f"answers/{pdf_name}")
if not pdf_path:
st.error(f"❌ Failed to download {pdf_name}")
continue
# Get context
context = get_context_for_ocr(pdf_name)
# Convert to images
images = pdf_to_images(pdf_path)
if not images:
st.error(f"❌ Failed to convert {pdf_name} to images")
continue
# OCR each page
all_text = []
pdf_tokens = 0
for idx, image in enumerate(images):
st.caption(f"Page {idx + 1}/{len(images)}...")
page_text, tokens = ocr_with_claude(claude, image, context)
if page_text:
all_text.append(f"\n\n=== Page {idx + 1} ===\n\n{page_text}")
pdf_tokens += tokens
full_text = "\n".join(all_text)
# Save to OCR_SCANNED
success = save_ocr_scanned_to_hf(pdf_name, full_text)
save_ocr_to_hf(pdf_name, full_text, pdf_tokens)
if success:
cost = pdf_tokens * 0.000003
total_cost += cost
total_tokens += pdf_tokens
success_count += 1
st.success(f"βœ… {pdf_name} - Tokens: {pdf_tokens:,} | Cost: ${cost:.4f}")
# Remove from unscanned list
if 'unscanned_pdfs' in st.session_state:
if pdf_name in st.session_state.unscanned_pdfs:
st.session_state.unscanned_pdfs.remove(pdf_name)
else:
st.error(f"❌ Failed to save {pdf_name}")
except Exception as e:
st.error(f"❌ Error with {pdf_name}: {e}")
st.markdown("---")
# Final summary
st.markdown("## πŸ“Š Batch Processing Complete")
col1, col2, col3 = st.columns(3)
with col1:
st.metric("βœ… Success", f"{success_count}/{len(pdf_list)}")
with col2:
st.metric("πŸ“ Total Tokens", f"{total_tokens:,}")
with col3:
st.metric("πŸ’° Total Cost", f"${total_cost:.4f}")
# ============================================================================
# INITIALIZE (Memory-efficient)
# ============================================================================
# Only initialize what's absolutely needed at startup
try:
qdrant = get_qdrant_client()
claude = get_claude_client()
st.sidebar.success("✅ Ready")
except Exception as e:
st.sidebar.error("⚠️ Setup")
st.error(f"❌ Initialization issue: {e}")
st.info("""
**Please check:**
- QDRANT_URL is set correctly
- QDRANT_API_KEY is set correctly
- ANTHROPIC_API_KEY is set correctly
- HF_TOKEN is set correctly
Go to your Streamlit Cloud settings → Secrets
""")
# Don't stop the app, allow user to see the error and fix it
qdrant = None
claude = None
# ============================================================================
# SIDEBAR
# ============================================================================
st.sidebar.title("🎓 Math AI v5.0")
# Show current model name without loading it
current_model_name = st.session_state.embedding_model
model_display_name = "Unknown"
for key, value in EMBEDDING_MODELS.items():
if value["name"] == current_model_name:
model_display_name = key
break
st.sidebar.info(f"**Model:** {model_display_name}")
st.sidebar.caption(f"**Chunking:** {st.session_state.chunking_strategy}")
st.sidebar.markdown("---")
# Use lazy loading for metrics to prevent startup issues
try:
vector_count = get_vector_count(qdrant)
st.sidebar.metric("πŸ“Š Vectors", f"{vector_count:,}")
except Exception as e:
st.sidebar.metric("πŸ“Š Vectors", "---")
st.sidebar.caption("Click refresh to load")
st.sidebar.markdown("---")
# Lazy load OCR stats
try:
scanned_files_count = len(list_ocr_scanned_files())
ocr_count, ocr_saved = get_ocr_cache_stats_hf()
st.sidebar.metric("πŸ“„ OCR Scanned", scanned_files_count)
st.sidebar.metric("πŸ’° Total Spent", f"${ocr_saved:.2f}")
st.sidebar.caption("πŸ’‘ OCR saved to OCR_SCANNED/")
except Exception as e:
st.sidebar.metric("πŸ“„ OCR Scanned", "---")
st.sidebar.caption("Loading...")
if st.sidebar.button("🗑️ Clear PDF"):
clear_pdf_cache()
st.rerun()
st.sidebar.markdown("---")
# Render the confirmation checkbox before the button so its value survives the rerun triggered by the click
confirm_delete = st.sidebar.checkbox("⚠️ Confirm deletion")
if st.sidebar.button("🗑️ DELETE VDB") and confirm_delete:
delete_vector_database(qdrant)
st.sidebar.success("Deleted!")
st.rerun()
# ============================================================================
# TABS
# ============================================================================
tab1, tab2, tab3, tab4 = st.tabs([
"📝 Data Preparation",
"🔢 Embedding Pipeline",
"🔍 Retrieval & Search",
"📈 Statistics"
])
# ============================================================================
# TAB 1: DATA PREPARATION
# ============================================================================
with tab1:
st.title("📝 Data Preparation")
prep_tabs = st.tabs(["📥 Loading & Preview", "🌐 Public Datasets", "🖊️ OCR Processing"])
with prep_tabs[0]:
st.header("📥 Data Loading & Preview")
st.info("Upload PDFs to HuggingFace and preview them here. OCR_SCANNED folder contains processed answer sheets.")
folder = st.selectbox("Folder:", ["books/", "exams/", "answers/", "OCR_SCANNED/"])
if st.button("🔍 Scan", key="scan_folder_button"):
with st.spinner(f"Scanning {folder}..."):
if folder == "OCR_SCANNED/":
files = list_ocr_scanned_files()
else:
files = list_dataset_files(folder)
if files:
st.session_state.scanned = files
st.session_state.scanned_folder = folder
st.success(f"βœ… Found {len(files)} files in {folder}")
else:
st.warning(f"No files found in {folder}")
st.session_state.scanned = []
if 'scanned' in st.session_state and st.session_state.scanned:
st.subheader(f"πŸ“‚ Files in {st.session_state.get('scanned_folder', folder)}")
for file in st.session_state.scanned:
name = file.split('/')[-1]
with st.expander(f"πŸ“„ {name}"):
col1, col2 = st.columns([3, 1])
with col1:
st.caption(f"**Path:** `{file}`")
with col2:
if st.button(f"πŸ‘οΈ Preview", key=f"preview_{name}"):
st.session_state[f'show_preview_{name}'] = True
# Show preview if button was clicked
if st.session_state.get(f'show_preview_{name}', False):
try:
if st.session_state.get('scanned_folder') == "OCR_SCANNED/":
# Load from OCR_SCANNED folder
text = load_ocr_scanned_from_hf(file)
if text:
st.success("βœ… OCR file loaded successfully")
st.text_area("Content:", text[:1000], height=200, key=f"content_{name}")
# Show stats
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Characters", len(text))
with col2:
st.metric("Words", len(text.split()))
with col3:
latex_count = text.count('\\') + text.count('$')
st.metric("LaTeX", latex_count)
else:
st.error("❌ Could not load file")
elif st.session_state.get('scanned_folder') == "answers/":
# Check if OCR version exists
if is_pdf_ocr_scanned(name):
st.success("βœ… This file has been OCR processed!")
if st.button("View OCR Version", key=f"view_ocr_{name}"):
text = load_ocr_scanned_from_hf(name)
if text:
st.text_area("OCR Version:", text[:500], height=150, key=f"ocr_{name}")
else:
st.warning("⏳ This file needs OCR processing")
# Show PDF preview
path, cached = download_with_cache(file)
if path:
text = extract_text_from_pdf(path)
if text:
st.text_area("PDF Text:", text[:500], height=150, key=f"pdf_{name}")
st.caption(f"{'πŸ”„ Downloaded' if not cached else 'πŸ’Ύ From cache'}")
else:
# Regular PDF preview for books/exams
path, cached = download_with_cache(file)
if path:
text = extract_text_from_pdf(path)
if text:
st.text_area("Text:", text[:500], height=150, key=f"text_{name}")
st.caption(f"{'πŸ”„ Downloaded' if not cached else 'πŸ’Ύ From cache'}")
except Exception as e:
st.error(f"Error previewing file: {e}")
with prep_tabs[1]:
st.header("🌐 Public Dataset Preview")
st.info("Preview sample problems from public math datasets before loading")
# Select category
category = st.selectbox("📚 Select Category:", list(PUBLIC_DATASETS.keys()))
category_info = PUBLIC_DATASETS[category]
st.caption(f"*{category_info['description']}*")
# Select dataset within category
dataset_names = [d['name'] for d in category_info['datasets']]
selected_dataset_name = st.selectbox("📊 Select Dataset:", dataset_names)
# Find the selected dataset info
dataset_info = None
for d in category_info['datasets']:
if d['name'] == selected_dataset_name:
dataset_info = d
break
if dataset_info:
# Display dataset info
col1, col2 = st.columns(2)
with col1:
st.metric("Relevance", dataset_info['relevance'])
st.caption(f"**Topics:** {dataset_info['topics']}")
with col2:
st.caption(f"**HF Path:** `{dataset_info['hf_path']}`")
if dataset_info.get('subset'):
st.caption(f"**Subset:** `{dataset_info['subset']}`")
# Sample count selector
sample_count = st.slider("Number of samples to preview:", 1, 10, 3)
# Preview button
if st.button("πŸ” Preview Samples", key=f"preview_{selected_dataset_name}"):
with st.spinner("Loading samples..."):
samples = preview_public_dataset(dataset_info, sample_count)
if samples:
st.success(f"βœ… Loaded {len(samples)} samples")
for i, sample in enumerate(samples, 1):
with st.expander(f"πŸ“ Sample {i}", expanded=(i==1)):
st.markdown("**Problem:**")
st.code(sample['problem'], language="text")
st.markdown("**Solution:**")
st.code(sample['solution'], language="text")
else:
st.error("Failed to load samples")
st.markdown("---")
st.info("πŸ’‘ **Tip:** After previewing, go to the Embedding Pipeline tab to load this dataset into your vector database")
with prep_tabs[2]:
st.header("🖊️ OCR Processing")
st.info("📋 AI-powered OCR with context from books/ and exams/ folders. Results saved to OCR_SCANNED/ folder.")
# Debug/Test Section
with st.expander("πŸ”§ Debug & Test Connection"):
st.subheader("Test HuggingFace Connection")
if st.button("πŸ§ͺ Test Connection"):
try:
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
st.error("❌ HF_TOKEN not found in environment")
else:
st.success("βœ… HF_TOKEN found")
# Test listing files
all_files = list_repo_files(
repo_id=DATASET_REPO,
repo_type="dataset",
token=hf_token
)
st.success(f"βœ… Connected to {DATASET_REPO}")
st.caption(f"Total files in dataset: {len(all_files)}")
# Show OCR_SCANNED files
ocr_files = [f for f in all_files if f.startswith(OCR_SCANNED_FOLDER)]
st.info(f"πŸ“ Files in OCR_SCANNED/: {len(ocr_files)}")
if ocr_files:
st.write("Files found:")
for f in ocr_files:
st.code(f)
# Show answers files
answer_files = [f for f in all_files if f.startswith("answers/")]
st.info(f"πŸ“ Files in answers/: {len(answer_files)}")
except Exception as e:
st.error(f"❌ Connection test failed: {e}")
import traceback
st.code(traceback.format_exc())
st.markdown("---")
# Display statistics
col1, col2, col3 = st.columns(3)
with col1:
total_answers = len(list_dataset_files("answers/"))
st.metric("πŸ“„ Total Answers", total_answers)
with col2:
scanned_count = len(list_ocr_scanned_files())
st.metric("βœ… Scanned", scanned_count)
with col3:
ocr_count, ocr_cost = get_ocr_cache_stats_hf()
st.metric("πŸ’° Total Cost", f"${ocr_cost:.2f}")
st.markdown("---")
# Scan for unscanned PDFs
if st.button("πŸ” Scan for Unprocessed Files", use_container_width=True):
with st.spinner("Scanning answers/ folder..."):
unscanned = get_unscanned_pdfs()
st.session_state.unscanned_pdfs = unscanned
if unscanned:
st.success(f"βœ… Found {len(unscanned)} files that need OCR processing")
else:
st.info("πŸŽ‰ All files are already processed!")
# Display unscanned files
if 'unscanned_pdfs' in st.session_state and st.session_state.unscanned_pdfs:
st.subheader(f"πŸ“‹ Unprocessed Files ({len(st.session_state.unscanned_pdfs)})")
for pdf_name in st.session_state.unscanned_pdfs:
with st.expander(f"πŸ“„ {pdf_name}", expanded=False):
col1, col2 = st.columns([3, 1])
with col1:
st.caption(f"Status: ⏳ Waiting for OCR")
st.caption(f"Will be saved to: `{get_ocr_scanned_filename(pdf_name)}`")
with col2:
if st.button("πŸš€ Process", key=f"ocr_{pdf_name}", use_container_width=True):
process_single_pdf_ocr(pdf_name)
st.markdown("---")
# Batch processing
st.subheader("⚑ Batch Processing")
col1, col2 = st.columns(2)
with col1:
batch_size = st.number_input(
"Number of files to process:",
min_value=1,
max_value=len(st.session_state.unscanned_pdfs),
value=min(3, len(st.session_state.unscanned_pdfs))
)
with col2:
st.metric("Estimated Cost", f"${batch_size * 0.05:.2f}")
st.caption("~$0.05 per file average")
if st.button("πŸš€ Process Batch", type="primary", use_container_width=True):
process_batch_ocr(st.session_state.unscanned_pdfs[:batch_size])
st.markdown("---")
# Preview OCR_SCANNED files
st.subheader("πŸ“‚ OCR_SCANNED Files Preview")
scanned_files = list_ocr_scanned_files()
if scanned_files:
st.success(f"βœ… Found {len(scanned_files)} OCR scanned files")
selected_file = st.selectbox(
"Select file to preview:",
scanned_files,
format_func=lambda x: x.split('/')[-1]
)
if st.button("πŸ‘οΈ Preview OCR Result", use_container_width=True):
with st.spinner("Loading OCR file..."):
ocr_text = load_ocr_scanned_from_hf(selected_file)
if ocr_text:
st.success("βœ… File loaded successfully")
st.text_area("OCR Transcription:", ocr_text, height=400)
# Quality indicators
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Characters", f"{len(ocr_text):,}")
with col2:
st.metric("Words", f"{len(ocr_text.split()):,}")
with col3:
latex_count = ocr_text.count('\\') + ocr_text.count('$')
st.metric("LaTeX Elements", latex_count)
else:
st.error("❌ Failed to load OCR file")
else:
st.info("πŸ“­ No scanned files yet. Process some files above!")
st.markdown("---")
# Reset/Management options
st.subheader("πŸ”§ Management")
col1, col2 = st.columns(2)
with col1:
if st.button("πŸ”„ Refresh Status", use_container_width=True):
# Clear cached data
if 'unscanned_pdfs' in st.session_state:
del st.session_state.unscanned_pdfs
st.rerun()
with col2:
with st.popover("⚠️ Reset OCR Cache"):
st.warning("This will clear the metadata cache (JSON files) but NOT the OCR_SCANNED files.")
st.caption("OCR_SCANNED files are permanent and must be deleted manually from HuggingFace.")
if st.button("Confirm Reset Cache"):
# Not implemented yet: clearing ocr_cache/ would require deleting the JSON files on HuggingFace
st.info("Cache metadata reset is not implemented yet; delete files under ocr_cache/ manually (OCR_SCANNED files are preserved)")
st.rerun()
# ============================================================================
# TAB 2: EMBEDDING PIPELINE
# ============================================================================
with tab2:
st.title("🔢 Embedding Pipeline")
st.header("⚙️ Configuration")
c1, c2 = st.columns(2)
with c1:
strategy = st.selectbox("Chunking:", list(CHUNKING_STRATEGIES.keys()))
chunk_size = st.slider("Size:", 50, 800, st.session_state.chunk_size)
if st.button("💾 Save"):
st.session_state.chunking_strategy = strategy
st.session_state.chunk_size = chunk_size
st.success("Saved!")
with c2:
model = st.selectbox("Model:", list(EMBEDDING_MODELS.keys()))
if st.button("💾 Save Model"):
st.session_state.embedding_model = EMBEDDING_MODELS[model]['name']
st.success("Saved!")
# ============================================================================
# TAB 3: RETRIEVAL
# ============================================================================
with tab3:
st.title("🔍 Retrieval & Search")
query = st.text_area("Query:", height=150)
top_k = st.slider("Top K:", 3, 20, 5)
if st.button("🔍 SEARCH") and query:
if not qdrant:
st.error("❌ Qdrant client not initialized. Check your secrets.")
st.stop()
embedder = get_embedding_model(st.session_state.embedding_model)
query_emb = embedder.encode(query)
try:
results = qdrant.search(
collection_name=COLLECTION_NAME,
query_vector=query_emb.tolist(),
limit=top_k
)
if results:
st.success(f"Found {len(results)} results")
for i, r in enumerate(results, 1):
with st.expander(f"Result {i} ({r.score*100:.1f}%)"):
st.text_area("Content:", r.payload['content'], height=150, key=f"r_{i}")
else:
st.warning("No results")
except Exception as e:
st.error(f"Search failed: {e}")
# ============================================================================
# TAB 4: STATISTICS
# ============================================================================
with tab4:
st.title("📈 Statistics")
try:
total = get_vector_count(qdrant)
st.metric("Total Vectors", f"{total:,}")
except:
st.info("No data yet")