| import streamlit as st | |
| import os | |
| import time | |
| import base64 | |
| import hashlib | |
| from io import BytesIO | |
| from PIL import Image | |
| import PyPDF2 | |
| from pdf2image import convert_from_path | |
| from anthropic import Anthropic | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue | |
| from sentence_transformers import SentenceTransformer | |
| from huggingface_hub import hf_hub_download, list_repo_files, upload_file | |
| from pathlib import Path | |
| import shutil | |
| import json | |
| import re | |
| # ============================================================================ | |
| # PRODUCTION MATH AI SYSTEM v5.0 - FINAL | |
| # ============================================================================ | |
| st.set_page_config( | |
| page_title="Math AI System", | |
| page_icon="π", | |
| layout="wide" | |
| ) | |
| COLLECTION_NAME = "math_knowledge_base" | |
DATASET_REPO = "Hebaelsayed/math-ai-documents"  # ⚠️ Change this to your own HF dataset repo
| CACHE_DIR = Path("/tmp/hf_cache") | |
| CACHE_DIR.mkdir(exist_ok=True) | |
| OCR_CACHE_FOLDER = "ocr_cache/" | |
| OCR_SCANNED_FOLDER = "OCR_SCANNED/" # New folder for scanned documents | |
| # ============================================================================ | |
| # EMBEDDING MODELS (ALL FREE) | |
| # ============================================================================ | |
| EMBEDDING_MODELS = { | |
| "MiniLM-L6 (Fast, 384D)": { | |
| "name": "sentence-transformers/all-MiniLM-L6-v2", | |
| "dimensions": 384, | |
| "speed": "Very Fast", | |
| "quality": "Good", | |
| "cost": "FREE" | |
| }, | |
| "MiniLM-L12 (Balanced, 384D)": { | |
| "name": "sentence-transformers/all-MiniLM-L12-v2", | |
| "dimensions": 384, | |
| "speed": "Fast", | |
| "quality": "Better", | |
| "cost": "FREE" | |
| }, | |
| "MPNet (Best, 768D)": { | |
| "name": "sentence-transformers/all-mpnet-base-v2", | |
| "dimensions": 768, | |
| "speed": "Medium", | |
| "quality": "Excellent", | |
| "cost": "FREE" | |
| }, | |
| "BGE-Large (SOTA, 1024D)": { | |
| "name": "BAAI/bge-large-en-v1.5", | |
| "dimensions": 1024, | |
| "speed": "Slower", | |
| "quality": "State-of-the-art", | |
| "cost": "FREE" | |
| } | |
| } | |
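# NOTE: each option maps to a fixed vector dimension (384 / 768 / 1024). The Qdrant
# collection is created with a single dimension, so switching the embedding model
# later generally requires re-embedding the documents into a recreated collection.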
| # ============================================================================ | |
| # CHUNKING STRATEGIES (ALL FREE) | |
| # ============================================================================ | |
| CHUNKING_STRATEGIES = { | |
| "Fixed-Size": { | |
| "description": "Simple word-based chunking with overlap", | |
| "pros": "Fast, predictable", | |
| "cons": "May split equations", | |
| "cost": "FREE" | |
| }, | |
| "Semantic": { | |
| "description": "Split at paragraph boundaries", | |
| "pros": "Preserves meaning", | |
| "cons": "Variable sizes", | |
| "cost": "FREE" | |
| }, | |
| "Exercise-Aware": { | |
| "description": "Split by Exercise 1, 2, etc.", | |
| "pros": "Perfect for exams", | |
| "cons": "Structured docs only", | |
| "cost": "FREE" | |
| }, | |
| "LaTeX-Aware": { | |
| "description": "Preserve math expressions", | |
| "pros": "Never splits equations", | |
| "cons": "More complex", | |
| "cost": "FREE" | |
| } | |
| } | |
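# NOTE: "LaTeX-Aware" is offered in the UI, but chunk_text() further below only
# implements Fixed-Size, Semantic and Exercise-Aware and silently falls back to
# Fixed-Size for any other strategy name.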
| # ============================================================================ | |
| # PUBLIC DATASETS (FIXED & WORKING) | |
| # ============================================================================ | |
| PUBLIC_DATASETS = { | |
| "Linear Algebra": { | |
| "description": "Matrices, Eigenvalues, SVD", | |
| "datasets": [ | |
| { | |
| "name": "Hendrycks MATH (Precalculus)", | |
| "hf_path": "EleutherAI/hendrycks_math", | |
| "subset": "precalculus", | |
| "problem_field": "problem", | |
| "solution_field": "solution", | |
| "relevance": "βββββ", | |
| "topics": "Matrices, eigenvalues" | |
| }, | |
| { | |
| "name": "DART-Math-Hard", | |
| "hf_path": "hkust-nlp/dart-math-hard", | |
| "subset": None, | |
| "problem_field": "query", | |
| "solution_field": "response", | |
| "relevance": "ββββ", | |
| "topics": "Difficult reasoning" | |
| }, | |
| { | |
| "name": "MATH-Hard", | |
| "hf_path": "lighteval/MATH-Hard", | |
| "subset": None, | |
| "problem_field": "problem", | |
| "solution_field": "solution", | |
| "relevance": "ββββ", | |
| "topics": "Competition math" | |
| } | |
| ] | |
| }, | |
| "Optimization & Calculus": { | |
| "description": "Gradients, Convex Optimization", | |
| "datasets": [ | |
| { | |
| "name": "Hendrycks MATH (Algebra)", | |
| "hf_path": "EleutherAI/hendrycks_math", | |
| "subset": "intermediate_algebra", | |
| "problem_field": "problem", | |
| "solution_field": "solution", | |
| "relevance": "βββββ", | |
| "topics": "Calculus, optimization" | |
| } | |
| ] | |
| } | |
| } | |
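# Each entry records the dataset's column names ("problem_field" / "solution_field")
# so preview_public_dataset() can read any of these datasets generically; adding a
# new dataset only requires adding its mapping here.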
| # ============================================================================ | |
| # SESSION STATE | |
| # ============================================================================ | |
| if 'embedding_model' not in st.session_state: | |
| st.session_state.embedding_model = EMBEDDING_MODELS["MiniLM-L6 (Fast, 384D)"]["name"] | |
| if 'chunking_strategy' not in st.session_state: | |
| st.session_state.chunking_strategy = "Fixed-Size" | |
| if 'chunk_size' not in st.session_state: | |
| st.session_state.chunk_size = 150 | |
| if 'chunk_overlap' not in st.session_state: | |
| st.session_state.chunk_overlap = 30 | |
| # ============================================================================ | |
| # CACHED RESOURCES | |
| # ============================================================================ | |
@st.cache_resource
def get_qdrant_client():
    # Cached across Streamlit reruns so a new client is not created on every interaction
    return QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY")
    )

@st.cache_resource
def get_claude_client():
    # Cached across Streamlit reruns
    return Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

@st.cache_resource
def get_embedding_model(model_name):
    """Lazy-load and cache the embedding model only when needed."""
    with st.spinner(f"Loading embedding model: {model_name}..."):
        return SentenceTransformer(model_name)
| # ============================================================================ | |
| # CACHE FUNCTIONS | |
| # ============================================================================ | |
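# Downloaded PDFs are cached under /tmp/hf_cache, keyed by an MD5 hash of the
# repo-relative path (used only as a filename-safe cache key, not for security).
# OCR results are cached separately inside the HF dataset repo (ocr_cache/*.json).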
| def get_cache_path(file_path): | |
| file_hash = hashlib.md5(file_path.encode()).hexdigest() | |
| return CACHE_DIR / f"{file_hash}.pdf" | |
| def is_file_cached(file_path): | |
| return get_cache_path(file_path).exists() | |
| def download_with_cache(file_path): | |
| cache_path = get_cache_path(file_path) | |
| if cache_path.exists(): | |
| return str(cache_path), True | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| downloaded_path = hf_hub_download( | |
| repo_id=DATASET_REPO, | |
| filename=file_path, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| shutil.copy(downloaded_path, cache_path) | |
| return str(cache_path), False | |
| except Exception as e: | |
| st.error(f"Download error: {e}") | |
| return None, False | |
| def clear_pdf_cache(): | |
| if CACHE_DIR.exists(): | |
| shutil.rmtree(CACHE_DIR) | |
| CACHE_DIR.mkdir(exist_ok=True) | |
| return True | |
| def get_ocr_filename(pdf_name): | |
| return f"{OCR_CACHE_FOLDER}{pdf_name}.json" | |
| def is_ocr_cached_hf(pdf_name): | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| all_files = list_repo_files( | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| return get_ocr_filename(pdf_name) in all_files | |
| except: | |
| return False | |
| def save_ocr_to_hf(pdf_name, transcribed_text, total_tokens): | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| cache_data = { | |
| "file_name": pdf_name, | |
| "transcribed_text": transcribed_text, | |
| "total_tokens": total_tokens, | |
| "timestamp": time.time(), | |
| "cost": total_tokens * 0.000003 | |
| } | |
| temp_path = Path("/tmp") / f"{pdf_name}_ocr.json" | |
| with open(temp_path, 'w', encoding='utf-8') as f: | |
| json.dump(cache_data, f, ensure_ascii=False, indent=2) | |
| upload_file( | |
| path_or_fileobj=str(temp_path), | |
| path_in_repo=get_ocr_filename(pdf_name), | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| temp_path.unlink() | |
| return True | |
| except Exception as e: | |
| st.warning(f"OCR save failed: {e}") | |
| return False | |
| def load_ocr_from_hf(pdf_name): | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| ocr_path = hf_hub_download( | |
| repo_id=DATASET_REPO, | |
| filename=get_ocr_filename(pdf_name), | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| with open(ocr_path, 'r', encoding='utf-8') as f: | |
| cache_data = json.load(f) | |
| return cache_data.get('transcribed_text'), cache_data.get('total_tokens', 0) | |
| except: | |
| return None, 0 | |
| def get_ocr_cache_stats_hf(): | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| all_files = list_repo_files( | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| ocr_files = [f for f in all_files if f.startswith(OCR_CACHE_FOLDER)] | |
| total_cost = 0.0 | |
| for ocr_file in ocr_files: | |
| try: | |
| ocr_path = hf_hub_download( | |
| repo_id=DATASET_REPO, | |
| filename=ocr_file, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| with open(ocr_path, 'r') as f: | |
| data = json.load(f) | |
| total_cost += data.get('cost', 0) | |
| except: | |
| pass | |
| return len(ocr_files), total_cost | |
| except: | |
| return 0, 0.0 | |
| # ============================================================================ | |
| # OCR_SCANNED FOLDER FUNCTIONS | |
| # ============================================================================ | |
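# Convention: answers/<name>.pdf gets a plain-text transcription at
# OCR_SCANNED/<name>.txt (human-readable, permanent), while ocr_cache/<name>.pdf.json
# keeps the token-count and cost metadata for the same OCR run.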
| def list_ocr_scanned_files(): | |
| """List all files in OCR_SCANNED folder""" | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| if not hf_token: | |
| return [] | |
| all_files = list_repo_files( | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| # Get all files in OCR_SCANNED, not just .txt (for compatibility) | |
| ocr_files = [f for f in all_files if f.startswith(OCR_SCANNED_FOLDER) and not f.endswith('/')] | |
| return ocr_files | |
| except Exception as e: | |
| # Don't show error at startup, only when actively used | |
| return [] | |
| def get_ocr_scanned_filename(pdf_name): | |
| """Generate OCR_SCANNED filename from PDF name""" | |
| base_name = pdf_name.replace('.pdf', '') | |
| return f"{OCR_SCANNED_FOLDER}{base_name}.txt" | |
| def is_pdf_ocr_scanned(pdf_name): | |
| """Check if PDF has corresponding OCR_SCANNED file""" | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| all_files = list_repo_files( | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| return get_ocr_scanned_filename(pdf_name) in all_files | |
| except Exception as e: | |
| st.error(f"Error checking OCR status for {pdf_name}: {e}") | |
| return False | |
| def save_ocr_scanned_to_hf(pdf_name, transcribed_text): | |
| """Save OCR transcription to OCR_SCANNED folder as .txt""" | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| # Create temporary txt file | |
| temp_path = Path("/tmp") / f"{pdf_name}_ocr.txt" | |
| with open(temp_path, 'w', encoding='utf-8') as f: | |
| f.write(transcribed_text) | |
| # Upload to HuggingFace | |
| upload_file( | |
| path_or_fileobj=str(temp_path), | |
| path_in_repo=get_ocr_scanned_filename(pdf_name), | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| temp_path.unlink() | |
| return True | |
| except Exception as e: | |
| st.warning(f"OCR save to OCR_SCANNED failed: {e}") | |
| return False | |
| def load_ocr_scanned_from_hf(filename): | |
| """Load OCR transcription from OCR_SCANNED folder""" | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| # If filename doesn't include folder path, add it | |
| if not filename.startswith(OCR_SCANNED_FOLDER): | |
| if filename.endswith('.pdf'): | |
| filename = get_ocr_scanned_filename(filename) | |
| else: | |
| filename = f"{OCR_SCANNED_FOLDER}{filename}" | |
| ocr_path = hf_hub_download( | |
| repo_id=DATASET_REPO, | |
| filename=filename, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| with open(ocr_path, 'r', encoding='utf-8') as f: | |
| return f.read() | |
| except Exception as e: | |
| st.error(f"Error loading {filename}: {e}") | |
| return None | |
| def get_unscanned_pdfs(): | |
| """Get list of PDFs in answers/ that don't have OCR_SCANNED files""" | |
| try: | |
| answer_files = list_dataset_files("answers/") | |
| answer_names = [f.split('/')[-1] for f in answer_files] | |
| unscanned = [] | |
| for name in answer_names: | |
| if not is_pdf_ocr_scanned(name): | |
| unscanned.append(name) | |
| return unscanned | |
| except Exception as e: | |
| st.error(f"Error getting unscanned PDFs: {e}") | |
| return [] | |
| def get_context_for_ocr(pdf_name): | |
| """Load context from books/ and exams/ folders for better OCR""" | |
| context_text = "" | |
| try: | |
| # Try to load from books folder | |
| book_files = list_dataset_files("books/") | |
| if book_files: | |
| # Load first book as context (or implement smarter selection) | |
| book_path, _ = download_with_cache(book_files[0]) | |
| if book_path: | |
| book_text = extract_text_from_pdf(book_path) | |
| if book_text: | |
| context_text += f"BOOK CONTEXT:\n{book_text[:3000]}\n\n" | |
| # Try to find matching exam | |
| exam_name = pdf_name.replace('answers', 'exam').replace('Answers', 'Exam') | |
| exam_files = list_dataset_files("exams/") | |
| for exam_file in exam_files: | |
| if exam_name in exam_file: | |
| exam_path, _ = download_with_cache(exam_file) | |
| if exam_path: | |
| exam_text = extract_text_from_pdf(exam_path) | |
| if exam_text: | |
| context_text += f"EXAM CONTEXT:\n{exam_text[:3000]}\n\n" | |
| break | |
| except: | |
| pass | |
| return context_text | |
| # ============================================================================ | |
| # CHUNKING FUNCTIONS | |
| # ============================================================================ | |
def chunk_text_fixed(text, chunk_size=150, overlap=30):
    words = text.split()
    chunks = []
    # Guard against a non-positive step when overlap >= chunk_size
    step = max(1, chunk_size - overlap)
    for i in range(0, len(words), step):
        chunk = ' '.join(words[i:i + chunk_size])
        if chunk.strip():
            chunks.append(chunk)
    return chunks
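# Worked example for chunk_text_fixed: with 1,000 words, chunk_size=150 and
# overlap=30 the window advances by 120 words, so chunks start at words
# 0, 120, 240, ..., 960 (9 chunks), with consecutive chunks sharing 30 words.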
| def chunk_text_semantic(text, max_chunk_size=500): | |
| paragraphs = text.split('\n\n') | |
| chunks = [] | |
| current_chunk = [] | |
| current_size = 0 | |
| for para in paragraphs: | |
| para_words = para.split() | |
| para_size = len(para_words) | |
| if current_size + para_size <= max_chunk_size: | |
| current_chunk.extend(para_words) | |
| current_size += para_size | |
| else: | |
| if current_chunk: | |
| chunks.append(' '.join(current_chunk)) | |
| current_chunk = para_words | |
| current_size = para_size | |
| if current_chunk: | |
| chunks.append(' '.join(current_chunk)) | |
| return chunks | |
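# Note for chunk_text_semantic: a single paragraph longer than max_chunk_size is
# kept intact and becomes one oversized chunk, so max_chunk_size is only an upper
# bound when chunks are built from several paragraphs.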
| def chunk_text_exercise_aware(text, max_chunk_size=500): | |
| exercise_pattern = r'Exercise\s+\d+' | |
| exercises = re.split(exercise_pattern, text) | |
| chunks = [] | |
| for exercise in exercises: | |
| if not exercise.strip(): | |
| continue | |
| words = exercise.split() | |
| if len(words) <= max_chunk_size: | |
| chunks.append(exercise.strip()) | |
| else: | |
| for i in range(0, len(words), max_chunk_size): | |
| chunk = ' '.join(words[i:i + max_chunk_size]) | |
| if chunk.strip(): | |
| chunks.append(chunk) | |
| return chunks | |
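# Note for chunk_text_exercise_aware: re.split() with a non-capturing pattern drops
# the "Exercise N" headers themselves; wrap the pattern in a capturing group (and
# re-join) if the header text should stay attached to each chunk.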
| def chunk_text(text, chunk_size=150, overlap=30, strategy="Fixed-Size"): | |
| if strategy == "Fixed-Size": | |
| return chunk_text_fixed(text, chunk_size, overlap) | |
| elif strategy == "Semantic": | |
| return chunk_text_semantic(text, chunk_size) | |
| elif strategy == "Exercise-Aware": | |
| return chunk_text_exercise_aware(text, chunk_size) | |
| else: | |
| return chunk_text_fixed(text, chunk_size, overlap) | |
| # ============================================================================ | |
| # HELPER FUNCTIONS | |
| # ============================================================================ | |
| def check_if_processed(qdrant, file_name, chunk_size=None, embedding_model=None, strategy="filename_only"): | |
| try: | |
| collection_info = qdrant.get_collection(collection_name=COLLECTION_NAME) | |
| if collection_info.points_count == 0: | |
| return False, 0 | |
| except: | |
| return False, 0 | |
| try: | |
| filter_conditions = [ | |
| FieldCondition(key="source_name", match=MatchValue(value=file_name)) | |
| ] | |
| if strategy in ["filename_settings", "filename_full"]: | |
| if chunk_size is not None: | |
| filter_conditions.append( | |
| FieldCondition(key="chunk_size", match=MatchValue(value=chunk_size)) | |
| ) | |
| if strategy == "filename_full": | |
| if embedding_model is not None: | |
| filter_conditions.append( | |
| FieldCondition(key="embedding_model", match=MatchValue(value=embedding_model)) | |
| ) | |
| count_result = qdrant.count( | |
| collection_name=COLLECTION_NAME, | |
| count_filter=Filter(must=filter_conditions) | |
| ) | |
| return count_result.count > 0, count_result.count | |
| except: | |
| return False, 0 | |
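# Deduplication strategies used by check_if_processed:
#   "filename_only"     - skip a file if ANY chunks from it already exist
#   "filename_settings" - additionally require a matching chunk_size
#   "filename_full"     - additionally require a matching embedding_model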
| def list_dataset_files(folder_path): | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| all_files = list_repo_files( | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| return [f for f in all_files if f.startswith(folder_path) and f.endswith('.pdf')] | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| return [] | |
| def extract_text_from_pdf(pdf_path): | |
| try: | |
| with open(pdf_path, 'rb') as file: | |
| reader = PyPDF2.PdfReader(file) | |
| text = "" | |
| for page_num, page in enumerate(reader.pages): | |
| text += f"\n\n=== Page {page_num + 1} ===\n\n{page.extract_text()}" | |
| return text | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| return None | |
| def pdf_to_images(pdf_path): | |
| try: | |
| images = convert_from_path(pdf_path, dpi=200) | |
| return images | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| st.info("π‘ Add 'poppler-utils' to packages.txt") | |
| return [] | |
| def resize_image(image, max_size=(2048, 2048)): | |
| image.thumbnail(max_size, Image.Resampling.LANCZOS) | |
| return image | |
| def image_to_base64(image): | |
| buffered = BytesIO() | |
| image.save(buffered, format="PNG") | |
| return base64.b64encode(buffered.getvalue()).decode() | |
| def ocr_with_claude(claude_client, image, context=""): | |
| resized = resize_image(image.copy()) | |
| img_b64 = image_to_base64(resized) | |
| prompt = f"""Transcribe handwritten math exam. | |
| CONTEXT: {context[:2000] if context else ""} | |
| Use LaTeX notation. Preserve structure. | |
| OUTPUT: Transcription only.""" | |
| try: | |
| message = claude_client.messages.create( | |
| model="claude-sonnet-4-20250514", | |
| max_tokens=4000, | |
| messages=[{ | |
| "role": "user", | |
| "content": [ | |
| {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img_b64}}, | |
| {"type": "text", "text": prompt} | |
| ] | |
| }] | |
| ) | |
| return message.content[0].text, message.usage.input_tokens + message.usage.output_tokens | |
| except Exception as e: | |
| st.error(f"OCR error: {e}") | |
| return None, 0 | |
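# ocr_with_claude returns input_tokens + output_tokens for the page; the
# $0.000003/token figure used for cost reporting elsewhere is a rough blended
# estimate, since input and output tokens are actually priced differently.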
| def get_vector_count(qdrant): | |
| try: | |
| collection_info = qdrant.get_collection(collection_name=COLLECTION_NAME) | |
| return collection_info.points_count | |
| except: | |
| return 0 | |
| def delete_vector_database(qdrant): | |
| try: | |
| qdrant.delete_collection(collection_name=COLLECTION_NAME) | |
| return True | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| return False | |
| def preview_public_dataset(dataset_info, sample_count): | |
| try: | |
| from datasets import load_dataset | |
| hf_path = dataset_info['hf_path'] | |
| subset = dataset_info.get('subset') | |
| split = dataset_info.get('split', 'train') | |
| problem_field = dataset_info.get('problem_field') | |
| solution_field = dataset_info.get('solution_field') | |
| if subset: | |
| dataset = load_dataset(hf_path, subset, split=split, streaming=True, trust_remote_code=True) | |
| else: | |
| dataset = load_dataset(hf_path, split=split, streaming=True, trust_remote_code=True) | |
| samples = [] | |
| for i, item in enumerate(dataset): | |
| if i >= sample_count: | |
| break | |
| problem = str(item.get(problem_field, 'N/A')) if problem_field else 'N/A' | |
| solution = str(item.get(solution_field, 'N/A')) if solution_field else 'N/A' | |
| samples.append({ | |
| 'problem': problem[:1000], | |
| 'solution': solution[:1000] | |
| }) | |
| return samples | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| return [] | |
| # ============================================================================ | |
| # OCR PROCESSING FUNCTIONS | |
| # ============================================================================ | |
| def process_single_pdf_ocr(pdf_name): | |
| """Process a single PDF through OCR pipeline""" | |
| if not claude: | |
| st.error("β Claude client not initialized. Check your ANTHROPIC_API_KEY secret.") | |
| return | |
| try: | |
| with st.spinner(f"Processing {pdf_name}..."): | |
| # Download PDF | |
| pdf_path, _ = download_with_cache(f"answers/{pdf_name}") | |
| if not pdf_path: | |
| st.error("Failed to download PDF") | |
| return | |
| # Get context from books and exams | |
| with st.spinner("Loading context from books and exams..."): | |
| context = get_context_for_ocr(pdf_name) | |
| # Convert PDF to images | |
| with st.spinner("Converting PDF to images..."): | |
| images = pdf_to_images(pdf_path) | |
| if not images: | |
| st.error("Failed to convert PDF to images") | |
| return | |
| # OCR each page | |
| all_text = [] | |
| total_tokens = 0 | |
| progress_bar = st.progress(0) | |
| for idx, image in enumerate(images): | |
| with st.spinner(f"OCR processing page {idx + 1}/{len(images)}..."): | |
| page_text, tokens = ocr_with_claude(claude, image, context) | |
| if page_text: | |
| all_text.append(f"\n\n=== Page {idx + 1} ===\n\n{page_text}") | |
| total_tokens += tokens | |
| progress_bar.progress((idx + 1) / len(images)) | |
| progress_bar.empty() | |
| # Combine all text | |
| full_text = "\n".join(all_text) | |
| # Save to OCR_SCANNED folder | |
| with st.spinner("Saving to OCR_SCANNED folder..."): | |
| success = save_ocr_scanned_to_hf(pdf_name, full_text) | |
| # Also save metadata to cache | |
| with st.spinner("Saving metadata..."): | |
| save_ocr_to_hf(pdf_name, full_text, total_tokens) | |
| if success: | |
| cost = total_tokens * 0.000003 | |
| st.success(f"β Successfully processed {pdf_name}") | |
| st.info(f"π Tokens: {total_tokens:,} | π° Cost: ${cost:.4f}") | |
| st.caption(f"πΎ Saved to: {get_ocr_scanned_filename(pdf_name)}") | |
| # Remove from unscanned list | |
| if 'unscanned_pdfs' in st.session_state: | |
| if pdf_name in st.session_state.unscanned_pdfs: | |
| st.session_state.unscanned_pdfs.remove(pdf_name) | |
| else: | |
| st.error("Failed to save OCR result") | |
| except Exception as e: | |
| st.error(f"Error processing {pdf_name}: {e}") | |
| def process_batch_ocr(pdf_list): | |
| """Process multiple PDFs in batch""" | |
| if not claude: | |
| st.error("β Claude client not initialized. Check your ANTHROPIC_API_KEY secret.") | |
| return | |
| total_cost = 0.0 | |
| total_tokens = 0 | |
| success_count = 0 | |
| for pdf_name in pdf_list: | |
| st.markdown(f"### Processing: {pdf_name}") | |
| try: | |
| # Download PDF | |
| pdf_path, _ = download_with_cache(f"answers/{pdf_name}") | |
| if not pdf_path: | |
| st.error(f"β Failed to download {pdf_name}") | |
| continue | |
| # Get context | |
| context = get_context_for_ocr(pdf_name) | |
| # Convert to images | |
| images = pdf_to_images(pdf_path) | |
| if not images: | |
| st.error(f"β Failed to convert {pdf_name} to images") | |
| continue | |
| # OCR each page | |
| all_text = [] | |
| pdf_tokens = 0 | |
| for idx, image in enumerate(images): | |
| st.caption(f"Page {idx + 1}/{len(images)}...") | |
| page_text, tokens = ocr_with_claude(claude, image, context) | |
| if page_text: | |
| all_text.append(f"\n\n=== Page {idx + 1} ===\n\n{page_text}") | |
| pdf_tokens += tokens | |
| full_text = "\n".join(all_text) | |
| # Save to OCR_SCANNED | |
| success = save_ocr_scanned_to_hf(pdf_name, full_text) | |
| save_ocr_to_hf(pdf_name, full_text, pdf_tokens) | |
| if success: | |
| cost = pdf_tokens * 0.000003 | |
| total_cost += cost | |
| total_tokens += pdf_tokens | |
| success_count += 1 | |
| st.success(f"β {pdf_name} - Tokens: {pdf_tokens:,} | Cost: ${cost:.4f}") | |
| # Remove from unscanned list | |
| if 'unscanned_pdfs' in st.session_state: | |
| if pdf_name in st.session_state.unscanned_pdfs: | |
| st.session_state.unscanned_pdfs.remove(pdf_name) | |
| else: | |
| st.error(f"β Failed to save {pdf_name}") | |
| except Exception as e: | |
| st.error(f"β Error with {pdf_name}: {e}") | |
| st.markdown("---") | |
| # Final summary | |
| st.markdown("## π Batch Processing Complete") | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| st.metric("β Success", f"{success_count}/{len(pdf_list)}") | |
| with col2: | |
| st.metric("π Total Tokens", f"{total_tokens:,}") | |
| with col3: | |
| st.metric("π° Total Cost", f"${total_cost:.4f}") | |
| # ============================================================================ | |
| # INITIALIZE (Memory-efficient) | |
| # ============================================================================ | |
| # Only initialize what's absolutely needed at startup | |
| try: | |
| qdrant = get_qdrant_client() | |
| claude = get_claude_client() | |
    st.sidebar.success("✅ Ready")
except Exception as e:
    st.sidebar.error("⚠️ Setup incomplete")
    st.error(f"❌ Initialization issue: {e}")
| st.info(""" | |
| **Please check:** | |
| - QDRANT_URL is set correctly | |
| - QDRANT_API_KEY is set correctly | |
| - ANTHROPIC_API_KEY is set correctly | |
| - HF_TOKEN is set correctly | |
| Go to your Streamlit Cloud settings β Secrets | |
| """) | |
| # Don't stop the app, allow user to see the error and fix it | |
| qdrant = None | |
| claude = None | |
| # ============================================================================ | |
| # SIDEBAR | |
| # ============================================================================ | |
| st.sidebar.title("π Math AI v5.0") | |
| # Show current model name without loading it | |
| current_model_name = st.session_state.embedding_model | |
| model_display_name = "Unknown" | |
| for key, value in EMBEDDING_MODELS.items(): | |
| if value["name"] == current_model_name: | |
| model_display_name = key | |
| break | |
| st.sidebar.info(f"**Model:** {model_display_name}") | |
| st.sidebar.caption(f"**Chunking:** {st.session_state.chunking_strategy}") | |
| st.sidebar.markdown("---") | |
| # Use lazy loading for metrics to prevent startup issues | |
| try: | |
| vector_count = get_vector_count(qdrant) | |
| st.sidebar.metric("π Vectors", f"{vector_count:,}") | |
| except Exception as e: | |
| st.sidebar.metric("π Vectors", "---") | |
| st.sidebar.caption("Click refresh to load") | |
| st.sidebar.markdown("---") | |
| # Lazy load OCR stats | |
| try: | |
| scanned_files_count = len(list_ocr_scanned_files()) | |
| ocr_count, ocr_saved = get_ocr_cache_stats_hf() | |
| st.sidebar.metric("π OCR Scanned", scanned_files_count) | |
| st.sidebar.metric("π° Total Spent", f"${ocr_saved:.2f}") | |
| st.sidebar.caption("π‘ OCR saved to OCR_SCANNED/") | |
| except Exception as e: | |
| st.sidebar.metric("π OCR Scanned", "---") | |
| st.sidebar.caption("Loading...") | |
if st.sidebar.button("🗑️ Clear PDF Cache"):
    clear_pdf_cache()
    st.rerun()
st.sidebar.markdown("---")
# Ask for confirmation BEFORE the delete button: a checkbox created inside the
# button branch resets on the next rerun, so the nested version never fires.
confirm_delete = st.sidebar.checkbox("⚠️ Confirm VDB deletion")
if st.sidebar.button("🗑️ DELETE VDB") and confirm_delete:
    delete_vector_database(qdrant)
    st.sidebar.success("Deleted!")
    st.rerun()
| # ============================================================================ | |
| # TABS | |
| # ============================================================================ | |
| tab1, tab2, tab3, tab4 = st.tabs([ | |
| "π Data Preparation", | |
| "π’ Embedding Pipeline", | |
| "π Retrieval & Search", | |
| "π Statistics" | |
| ]) | |
| # ============================================================================ | |
| # TAB 1: DATA PREPARATION | |
| # ============================================================================ | |
| with tab1: | |
| st.title("π Data Preparation") | |
| prep_tabs = st.tabs(["π₯ Loading & Preview", "π Public Datasets", "ποΈ OCR Processing"]) | |
| with prep_tabs[0]: | |
| st.header("π₯ Data Loading & Preview") | |
| st.info("Upload PDFs to HuggingFace and preview them here. OCR_SCANNED folder contains processed answer sheets.") | |
| folder = st.selectbox("Folder:", ["books/", "exams/", "answers/", "OCR_SCANNED/"]) | |
| if st.button("π Scan", key="scan_folder_button"): | |
| with st.spinner(f"Scanning {folder}..."): | |
| if folder == "OCR_SCANNED/": | |
| files = list_ocr_scanned_files() | |
| else: | |
| files = list_dataset_files(folder) | |
| if files: | |
| st.session_state.scanned = files | |
| st.session_state.scanned_folder = folder | |
| st.success(f"β Found {len(files)} files in {folder}") | |
| else: | |
| st.warning(f"No files found in {folder}") | |
| st.session_state.scanned = [] | |
| if 'scanned' in st.session_state and st.session_state.scanned: | |
| st.subheader(f"π Files in {st.session_state.get('scanned_folder', folder)}") | |
| for file in st.session_state.scanned: | |
| name = file.split('/')[-1] | |
| with st.expander(f"π {name}"): | |
| col1, col2 = st.columns([3, 1]) | |
| with col1: | |
| st.caption(f"**Path:** `{file}`") | |
| with col2: | |
| if st.button(f"ποΈ Preview", key=f"preview_{name}"): | |
| st.session_state[f'show_preview_{name}'] = True | |
| # Show preview if button was clicked | |
| if st.session_state.get(f'show_preview_{name}', False): | |
| try: | |
| if st.session_state.get('scanned_folder') == "OCR_SCANNED/": | |
| # Load from OCR_SCANNED folder | |
| text = load_ocr_scanned_from_hf(file) | |
| if text: | |
| st.success("β OCR file loaded successfully") | |
| st.text_area("Content:", text[:1000], height=200, key=f"content_{name}") | |
| # Show stats | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| st.metric("Characters", len(text)) | |
| with col2: | |
| st.metric("Words", len(text.split())) | |
| with col3: | |
| latex_count = text.count('\\') + text.count('$') | |
| st.metric("LaTeX", latex_count) | |
| else: | |
| st.error("β Could not load file") | |
| elif st.session_state.get('scanned_folder') == "answers/": | |
| # Check if OCR version exists | |
| if is_pdf_ocr_scanned(name): | |
| st.success("β This file has been OCR processed!") | |
| if st.button("View OCR Version", key=f"view_ocr_{name}"): | |
| text = load_ocr_scanned_from_hf(name) | |
| if text: | |
| st.text_area("OCR Version:", text[:500], height=150, key=f"ocr_{name}") | |
| else: | |
| st.warning("β³ This file needs OCR processing") | |
| # Show PDF preview | |
| path, cached = download_with_cache(file) | |
| if path: | |
| text = extract_text_from_pdf(path) | |
| if text: | |
| st.text_area("PDF Text:", text[:500], height=150, key=f"pdf_{name}") | |
| st.caption(f"{'π Downloaded' if not cached else 'πΎ From cache'}") | |
| else: | |
| # Regular PDF preview for books/exams | |
| path, cached = download_with_cache(file) | |
| if path: | |
| text = extract_text_from_pdf(path) | |
| if text: | |
| st.text_area("Text:", text[:500], height=150, key=f"text_{name}") | |
| st.caption(f"{'π Downloaded' if not cached else 'πΎ From cache'}") | |
| except Exception as e: | |
| st.error(f"Error previewing file: {e}") | |
| with prep_tabs[1]: | |
| st.header("π Public Dataset Preview") | |
| st.info("Preview sample problems from public math datasets before loading") | |
| # Select category | |
| category = st.selectbox("π Select Category:", list(PUBLIC_DATASETS.keys())) | |
| category_info = PUBLIC_DATASETS[category] | |
| st.caption(f"*{category_info['description']}*") | |
| # Select dataset within category | |
| dataset_names = [d['name'] for d in category_info['datasets']] | |
| selected_dataset_name = st.selectbox("π Select Dataset:", dataset_names) | |
| # Find the selected dataset info | |
| dataset_info = None | |
| for d in category_info['datasets']: | |
| if d['name'] == selected_dataset_name: | |
| dataset_info = d | |
| break | |
| if dataset_info: | |
| # Display dataset info | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.metric("Relevance", dataset_info['relevance']) | |
| st.caption(f"**Topics:** {dataset_info['topics']}") | |
| with col2: | |
| st.caption(f"**HF Path:** `{dataset_info['hf_path']}`") | |
| if dataset_info.get('subset'): | |
| st.caption(f"**Subset:** `{dataset_info['subset']}`") | |
| # Sample count selector | |
| sample_count = st.slider("Number of samples to preview:", 1, 10, 3) | |
| # Preview button | |
| if st.button("π Preview Samples", key=f"preview_{selected_dataset_name}"): | |
| with st.spinner("Loading samples..."): | |
| samples = preview_public_dataset(dataset_info, sample_count) | |
| if samples: | |
| st.success(f"β Loaded {len(samples)} samples") | |
| for i, sample in enumerate(samples, 1): | |
| with st.expander(f"π Sample {i}", expanded=(i==1)): | |
| st.markdown("**Problem:**") | |
| st.code(sample['problem'], language="text") | |
| st.markdown("**Solution:**") | |
| st.code(sample['solution'], language="text") | |
| else: | |
| st.error("Failed to load samples") | |
| st.markdown("---") | |
| st.info("π‘ **Tip:** After previewing, go to the Embedding Pipeline tab to load this dataset into your vector database") | |
| with prep_tabs[2]: | |
| st.header("ποΈ OCR Processing") | |
| st.info("π AI-powered OCR with context from books/ and exams/ folders. Results saved to OCR_SCANNED/ folder.") | |
| # Debug/Test Section | |
| with st.expander("π§ Debug & Test Connection"): | |
| st.subheader("Test HuggingFace Connection") | |
| if st.button("π§ͺ Test Connection"): | |
| try: | |
| hf_token = os.getenv("HF_TOKEN") | |
| if not hf_token: | |
| st.error("β HF_TOKEN not found in environment") | |
| else: | |
| st.success("β HF_TOKEN found") | |
| # Test listing files | |
| all_files = list_repo_files( | |
| repo_id=DATASET_REPO, | |
| repo_type="dataset", | |
| token=hf_token | |
| ) | |
| st.success(f"β Connected to {DATASET_REPO}") | |
| st.caption(f"Total files in dataset: {len(all_files)}") | |
| # Show OCR_SCANNED files | |
| ocr_files = [f for f in all_files if f.startswith(OCR_SCANNED_FOLDER)] | |
| st.info(f"π Files in OCR_SCANNED/: {len(ocr_files)}") | |
| if ocr_files: | |
| st.write("Files found:") | |
| for f in ocr_files: | |
| st.code(f) | |
| # Show answers files | |
| answer_files = [f for f in all_files if f.startswith("answers/")] | |
| st.info(f"π Files in answers/: {len(answer_files)}") | |
| except Exception as e: | |
| st.error(f"β Connection test failed: {e}") | |
| import traceback | |
| st.code(traceback.format_exc()) | |
| st.markdown("---") | |
| # Display statistics | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| total_answers = len(list_dataset_files("answers/")) | |
| st.metric("π Total Answers", total_answers) | |
| with col2: | |
| scanned_count = len(list_ocr_scanned_files()) | |
| st.metric("β Scanned", scanned_count) | |
| with col3: | |
| ocr_count, ocr_cost = get_ocr_cache_stats_hf() | |
| st.metric("π° Total Cost", f"${ocr_cost:.2f}") | |
| st.markdown("---") | |
| # Scan for unscanned PDFs | |
| if st.button("π Scan for Unprocessed Files", use_container_width=True): | |
| with st.spinner("Scanning answers/ folder..."): | |
| unscanned = get_unscanned_pdfs() | |
| st.session_state.unscanned_pdfs = unscanned | |
| if unscanned: | |
| st.success(f"β Found {len(unscanned)} files that need OCR processing") | |
| else: | |
| st.info("π All files are already processed!") | |
| # Display unscanned files | |
| if 'unscanned_pdfs' in st.session_state and st.session_state.unscanned_pdfs: | |
| st.subheader(f"π Unprocessed Files ({len(st.session_state.unscanned_pdfs)})") | |
| for pdf_name in st.session_state.unscanned_pdfs: | |
| with st.expander(f"π {pdf_name}", expanded=False): | |
| col1, col2 = st.columns([3, 1]) | |
| with col1: | |
| st.caption(f"Status: β³ Waiting for OCR") | |
| st.caption(f"Will be saved to: `{get_ocr_scanned_filename(pdf_name)}`") | |
| with col2: | |
| if st.button("π Process", key=f"ocr_{pdf_name}", use_container_width=True): | |
| process_single_pdf_ocr(pdf_name) | |
| st.markdown("---") | |
| # Batch processing | |
| st.subheader("β‘ Batch Processing") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| batch_size = st.number_input( | |
| "Number of files to process:", | |
| min_value=1, | |
| max_value=len(st.session_state.unscanned_pdfs), | |
| value=min(3, len(st.session_state.unscanned_pdfs)) | |
| ) | |
| with col2: | |
| st.metric("Estimated Cost", f"${batch_size * 0.05:.2f}") | |
| st.caption("~$0.05 per file average") | |
| if st.button("π Process Batch", type="primary", use_container_width=True): | |
| process_batch_ocr(st.session_state.unscanned_pdfs[:batch_size]) | |
| st.markdown("---") | |
| # Preview OCR_SCANNED files | |
| st.subheader("π OCR_SCANNED Files Preview") | |
| scanned_files = list_ocr_scanned_files() | |
| if scanned_files: | |
| st.success(f"β Found {len(scanned_files)} OCR scanned files") | |
| selected_file = st.selectbox( | |
| "Select file to preview:", | |
| scanned_files, | |
| format_func=lambda x: x.split('/')[-1] | |
| ) | |
| if st.button("ποΈ Preview OCR Result", use_container_width=True): | |
| with st.spinner("Loading OCR file..."): | |
| ocr_text = load_ocr_scanned_from_hf(selected_file) | |
| if ocr_text: | |
| st.success("β File loaded successfully") | |
| st.text_area("OCR Transcription:", ocr_text, height=400) | |
| # Quality indicators | |
| col1, col2, col3 = st.columns(3) | |
| with col1: | |
| st.metric("Characters", f"{len(ocr_text):,}") | |
| with col2: | |
| st.metric("Words", f"{len(ocr_text.split()):,}") | |
| with col3: | |
| latex_count = ocr_text.count('\\') + ocr_text.count('$') | |
| st.metric("LaTeX Elements", latex_count) | |
| else: | |
| st.error("β Failed to load OCR file") | |
| else: | |
| st.info("π No scanned files yet. Process some files above!") | |
| st.markdown("---") | |
| # Reset/Management options | |
| st.subheader("π§ Management") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| if st.button("π Refresh Status", use_container_width=True): | |
| # Clear cached data | |
| if 'unscanned_pdfs' in st.session_state: | |
| del st.session_state.unscanned_pdfs | |
| st.rerun() | |
| with col2: | |
| with st.popover("β οΈ Reset OCR Cache"): | |
| st.warning("This will clear the metadata cache (JSON files) but NOT the OCR_SCANNED files.") | |
| st.caption("OCR_SCANNED files are permanent and must be deleted manually from HuggingFace.") | |
| if st.button("Confirm Reset Cache"): | |
| # This would clear the ocr_cache folder, not OCR_SCANNED | |
| st.info("Cache metadata reset (OCR_SCANNED files preserved)") | |
| st.rerun() | |
| # ============================================================================ | |
| # TAB 2: EMBEDDING PIPELINE | |
| # ============================================================================ | |
| with tab2: | |
| st.title("π’ Embedding Pipeline") | |
| st.header("βοΈ Configuration") | |
| c1, c2 = st.columns(2) | |
| with c1: | |
| strategy = st.selectbox("Chunking:", list(CHUNKING_STRATEGIES.keys())) | |
| chunk_size = st.slider("Size:", 50, 800, st.session_state.chunk_size) | |
| if st.button("πΎ Save"): | |
| st.session_state.chunking_strategy = strategy | |
| st.session_state.chunk_size = chunk_size | |
| st.success("Saved!") | |
| with c2: | |
| model = st.selectbox("Model:", list(EMBEDDING_MODELS.keys())) | |
| if st.button("πΎ Save Model"): | |
| st.session_state.embedding_model = EMBEDDING_MODELS[model]['name'] | |
| st.success("Saved!") | |
| # ============================================================================ | |
| # TAB 3: RETRIEVAL | |
| # ============================================================================ | |
| with tab3: | |
| st.title("π Retrieval & Search") | |
| query = st.text_area("Query:", height=150) | |
| top_k = st.slider("Top K:", 3, 20, 5) | |
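    # The query embedding must come from the same model (and therefore the same
    # vector dimension) that was used when the collection was populated; Qdrant
    # rejects query vectors whose size does not match the collection.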
| if st.button("π SEARCH") and query: | |
| if not qdrant: | |
| st.error("β Qdrant client not initialized. Check your secrets.") | |
| st.stop() | |
| embedder = get_embedding_model(st.session_state.embedding_model) | |
| query_emb = embedder.encode(query) | |
| try: | |
| results = qdrant.search( | |
| collection_name=COLLECTION_NAME, | |
| query_vector=query_emb.tolist(), | |
| limit=top_k | |
| ) | |
| if results: | |
| st.success(f"Found {len(results)} results") | |
| for i, r in enumerate(results, 1): | |
| with st.expander(f"Result {i} ({r.score*100:.1f}%)"): | |
| st.text_area("Content:", r.payload['content'], height=150, key=f"r_{i}") | |
| else: | |
| st.warning("No results") | |
except Exception as e:
    st.error(f"Search failed: {e}")
| # ============================================================================ | |
| # TAB 4: STATISTICS | |
| # ============================================================================ | |
| with tab4: | |
| st.title("π Statistics") | |
| try: | |
| total = get_vector_count(qdrant) | |
| st.metric("Total Vectors", f"{total:,}") | |
except Exception:
    st.info("No data yet")