| from PIL import Image |
| import sys |
| import os |
| import re |
| import random |
| from typing import Dict, Any |
|
|
| |
# Location of the handwriting-transcription checkpoint. The path is resolved
# against the current working directory at import time, so it depends on
# where the process was started.
MODEL_PATH = os.path.join(
    os.getcwd(),
    "handwritten-math-transcription",
    "checkpoints",
    "model_v3_0.pth",
)

# Token groups consulted by get_symbol_weight(): operators are weighted 1.5,
# brackets/limits 1.3, visually ambiguous glyphs 0.7, everything else 1.0.
CRITICAL_OPERATORS = [
    "\\int", "\\sum", "=", "\\frac",
    "+", "-", "*", "\\times", "\\div",
]
BRACKETS_LIMITS = [
    "(", ")", "[", "]",
    "\\{", "\\}",
    "^", "_",
]
AMBIGUOUS_SYMBOLS = [
    "8", "B", "0", "O",
    "l", "1", "I",
    "S", "5", "Z", "2",
]

# Single-character matcher for East-Asian script and punctuation that should
# never appear in LaTeX output. The adjacent raw literals are concatenated by
# the compiler into one character class.
CJK_PATTERN = re.compile(
    r'['
    r'\u4e00-\u9fff'  # CJK Unified Ideographs
    r'\u3040-\u30ff'  # Hiragana and Katakana
    r'\uac00-\ud7af'  # Hangul Syllables
    r'\u3000-\u303f'  # CJK Symbols and Punctuation
    r'\uff00-\uffef'  # Halfwidth and Fullwidth Forms
    # Explicit punctuation marks; all of them already fall inside the
    # two punctuation/fullwidth ranges above.
    r'\u3001\u3002\uff0c\uff0e\uff1a\uff1b\uff1f\uff01'
    r']'
)
|
|
def get_symbol_weight(symbol: str) -> float:
    """Return the confidence weight assigned to a single LaTeX token.

    Args:
        symbol: One token, e.g. ``"\\int"``, ``"("`` or ``"x"``.

    Returns:
        1.5 for tokens in ``CRITICAL_OPERATORS``, 1.3 for tokens in
        ``BRACKETS_LIMITS``, 0.7 for tokens in ``AMBIGUOUS_SYMBOLS`` and
        1.0 for anything else. Groups are checked in that order.
    """
    weighted_groups = (
        (CRITICAL_OPERATORS, 1.5),
        (BRACKETS_LIMITS, 1.3),
        (AMBIGUOUS_SYMBOLS, 0.7),
    )
    for group, weight in weighted_groups:
        if symbol in group:
            return weight
    return 1.0
|
|
# One LaTeX token: a control word (backslash + letters), a control symbol
# (backslash + one non-space character, e.g. ``\{``), or any single
# non-whitespace character.
_LATEX_TOKEN_PATTERN = re.compile(r'\\[A-Za-z]+|\\\S|\S')


def _tokenize_latex(latex_string: str) -> list:
    """Split a LaTeX string into control words, control symbols and characters."""
    return _LATEX_TOKEN_PATTERN.findall(latex_string)


def calculate_weighted_confidence(latex_string: str, mock_logits: bool = True) -> float:
    """Compute OCR.conf = sum(W_i * c_i) / sum(W_i) over the tokens of a string.

    ``W_i`` comes from ``get_symbol_weight``. ``c_i`` is a per-token
    confidence: drawn uniformly from [0.85, 0.99] when ``mock_logits`` is
    true (so results are not deterministic), otherwise the constant 0.95.

    Tokenization follows LaTeX rules: a control word is a backslash followed
    by letters only, so ``\\frac12`` yields ``\\frac``, ``1``, ``2``; and a
    backslash followed by a non-letter is one token, so ``\\{`` and ``\\}``
    can be matched against the weighted symbol groups. Whitespace is ignored.

    Args:
        latex_string: Text to score. Empty or ``None`` scores 0.0.
        mock_logits: Use random per-token confidences instead of 0.95.

    Returns:
        The weighted mean confidence rounded to 4 decimals, or 0.0 when the
        string contains no tokens.
    """
    if not latex_string:
        return 0.0

    tokens = _tokenize_latex(latex_string)

    total_weighted_ci = 0.0
    total_weights = 0.0
    for token in tokens:
        w_i = get_symbol_weight(token)
        c_i = random.uniform(0.85, 0.99) if mock_logits else 0.95
        total_weighted_ci += w_i * c_i
        total_weights += w_i

    # No tokens (e.g. whitespace-only input): avoid dividing by zero.
    if total_weights == 0:
        return 0.0
    return round(total_weighted_ci / total_weights, 4)
|
|
def clean_latex_output(text: str) -> str:
    """Strip non-math residue from OCR text.

    Three passes are applied in order: every character matched by
    ``CJK_PATTERN`` is deleted; the whole words "solve", "find", "evaluate"
    and "simplify" are deleted case-insensitively; runs of two or more
    whitespace characters are collapsed to a single space and the result is
    trimmed.

    Args:
        text: Raw OCR text. Any falsy value yields ``""``.

    Returns:
        The cleaned string.
    """
    if not text:
        return ""

    without_cjk = CJK_PATTERN.sub('', text)
    without_instructions = re.sub(
        r'(?i)\b(solve|find|evaluate|simplify)\b', '', without_cjk
    )
    collapsed = re.sub(r'\s{2,}', ' ', without_instructions)
    return collapsed.strip()
|
|
def extract_latex_from_pix2text(out) -> str:
    """Extract cleaned text from a pix2text result of any supported shape.

    Handled shapes, in order:
      * ``str``: cleaned and returned.
      * ``list``: for each dict item the ``'text'`` value is used, falling
        back to ``'latex'`` when it is falsy; for any other item its
        ``text`` attribute is used if present. Items with no usable value
        are skipped. The cleaned parts are joined with single spaces.
      * an object with ``to_markdown``: the result of calling it with no
        arguments is cleaned and returned.
      * anything else: ``str(out)`` is cleaned and returned.

    Args:
        out: The value returned by a pix2text recognition call.

    Returns:
        The cleaned text; ``""`` when nothing usable is found.
    """
    if isinstance(out, str):
        return clean_latex_output(out)

    if isinstance(out, list):
        parts = []
        for item in out:
            if isinstance(item, dict):
                raw = item.get('text') or item.get('latex')
            else:
                raw = getattr(item, 'text', None)
            # A missing value must be skipped *before* str(): otherwise
            # str(None) would inject the literal text "None" into the output.
            if raw is None:
                continue
            raw = str(raw)
            if not raw:
                continue
            text = clean_latex_output(raw).strip()
            if text:
                parts.append(text)
        return ' '.join(parts)

    if hasattr(out, 'to_markdown'):
        return clean_latex_output(out.to_markdown())

    return clean_latex_output(str(out))
|
|
class MVM2OCREngine:
    """OCR front end that converts an image or InkML file into LaTeX text.

    Two optional backends are loaded at construction time; either may be
    missing without raising:
      * ``self.p2t``: a ``Pix2Text`` instance, used for raster images.
      * ``self.transcriber``: a ``HandwritingTranscriber``, used for
        ``.inkml`` files, loaded only if ``MODEL_PATH`` exists.
    """

    # Inline math inside prose: $$...$$ (may span lines), $...$ or \(...\).
    # The $$ alternative is listed first so "$$x$$" is not read as two
    # empty "$$" matches.
    _INLINE_MATH_PATTERN = re.compile(
        r'(?s:\$\$(.+?)\$\$)|\$(.*?)\$|\\\((.*?)\\\)'
    )

    def __init__(self):
        """Load the optional backends, falling back silently when unavailable."""
        self.model_loaded = False
        self.p2t = None
        try:
            # Imported lazily so the module still imports without pix2text.
            from pix2text import Pix2Text

            self.p2t = Pix2Text.from_config()
            self.model_loaded = True
            print("[OCR] Pix2Text loaded successfully.")
        except Exception as e:
            # Best effort by design: import or model-load failure only
            # disables this backend.
            print(f"[OCR] Warning: Pix2Text unavailable ({e}). Using simulation mode.")

        self.transcriber = None
        try:
            from handwriting_transcriber import HandwritingTranscriber
            if os.path.exists(MODEL_PATH):
                self.transcriber = HandwritingTranscriber(model_path=MODEL_PATH)
                print(f"[OCR] HandwritingTranscriber loaded with model: {MODEL_PATH}")
            else:
                print(f"[OCR] Warning: Handwriting model not found at {MODEL_PATH}")
        except Exception as e:
            print(f"[OCR] Warning: HandwritingTranscriber unavailable ({e})")

    def _extract_formulas_only(self, pix2text_output) -> str:
        """Extract only math-formula content, discarding prose.

        Args:
            pix2text_output: Either a string or a list of region dicts.
                NOTE(review): the dict keys used here (``'type'``,
                ``'text'``, ``'latex'``) are assumed to match the pix2text
                result schema; confirm against the installed version.

        Returns:
            For a string: the cleaned string if it contains any of
            ``\\ ^ _ = + -``, else ``""``. For a list: the cleaned formula
            parts joined by newlines. ``""`` for any other input.
        """
        if isinstance(pix2text_output, str):
            if any(op in pix2text_output for op in ['\\', '^', '_', '=', '+', '-']):
                return clean_latex_output(pix2text_output)
            return ""

        if not isinstance(pix2text_output, list):
            return ""

        formula_parts = []
        for item in pix2text_output:
            if not isinstance(item, dict):
                continue
            item_type = item.get('type', 'text')
            if item_type in ('isolated_equation', 'embedding', 'formula', 'math'):
                # Skip null/empty fields instead of stringifying None
                # into the literal text "None".
                text = item.get('text') or item.get('latex')
                if not text:
                    continue
                text = clean_latex_output(str(text)).strip()
                if text:
                    formula_parts.append(text)
            elif item_type == 'text':
                # A present-but-None 'text' would make the regex search raise.
                raw = str(item.get('text') or '')
                for groups in self._INLINE_MATH_PATTERN.findall(raw):
                    part = next((g for g in groups if g), '')
                    if not part.strip():
                        continue
                    part = clean_latex_output(part).strip()
                    if part:
                        formula_parts.append(part)
        return '\n'.join(formula_parts)

    def _recognize_with_pix2text(self, image_path: str) -> str:
        """Run the Pix2Text passes and return the recognised text ("" if none).

        Pass 1 uses ``recognize_formula``; its result is rejected when it
        contains ``\\newcommand`` or ``\\def``. If pass 1 yields fewer than
        three characters, pass 2 runs ``recognize`` and keeps formula regions
        only, falling back to all extracted text when no formula is found.
        Exceptions from pass 2 propagate to the caller.
        """
        raw_latex = ""
        try:
            formula_out = self.p2t.recognize_formula(image_path)
            raw_latex = clean_latex_output(str(formula_out)).strip()
            if "\\newcommand" in raw_latex or "\\def" in raw_latex:
                print("[OCR] Pass 1 hallucinated preamble macros. Rejecting output.")
                raw_latex = ""
            else:
                print(f"[OCR] Pass 1 (formula mode): {raw_latex[:80]}")
        except Exception as e1:
            print(f"[OCR] Pass 1 formula mode failed: {e1}")
            raw_latex = ""

        if len(raw_latex) >= 3:
            return raw_latex

        out2 = self.p2t.recognize(image_path)
        raw_latex = self._extract_formulas_only(out2)
        print(f"[OCR] Pass 2 (formula extraction): {raw_latex[:80]}")
        if not raw_latex.strip():
            raw_latex = extract_latex_from_pix2text(out2)
        return raw_latex

    def _process_inkml(self, inkml_path: str) -> Dict[str, Any]:
        """Transcribe an InkML file with the handwriting backend."""
        if self.transcriber is None:
            return {
                "error": "InkML input requires HandwritingTranscriber, which is not loaded",
                "latex_output": "",
                "weighted_confidence": 0.0,
            }

        raw_latex = ""
        try:
            raw_latex, _ = self.transcriber.transcribe_inkml(inkml_path)
            print(f"[OCR] Used HandwritingTranscriber for InkML: {raw_latex}")
        except Exception as e:
            print(f"[OCR] HandwritingTranscriber error: {e}")
            raw_latex = ""

        raw_latex = str(raw_latex or "")
        recognised = bool(raw_latex.strip())
        if not recognised:
            raw_latex = "No mathematical formula detected."
        return {
            "latex_output": raw_latex,
            "weighted_confidence": calculate_weighted_confidence(raw_latex) if recognised else 0.0,
            "backend": "handwriting",
        }

    def process_image(self, image_path: str) -> Dict[str, Any]:
        """Run the full OCR pipeline on one file.

        ``.inkml`` files (case-insensitive) go straight to the handwriting
        transcriber: they are not raster images, so validating them with PIL
        would always reject them. Every other file is validated with PIL and
        then processed by Pix2Text when it is loaded.

        Args:
            image_path: Path to an image or ``.inkml`` file.

        Returns:
            On a pre-flight failure (missing file, unreadable or zero-size
            image, InkML without a transcriber): a dict with ``error``,
            ``latex_output`` = ``""`` and ``weighted_confidence`` = 0.0.
            Otherwise a dict with ``latex_output``, ``weighted_confidence``
            and ``backend`` (``"handwriting"``, ``"pix2text-formula"`` or
            ``"simulation"``). When nothing was recognised ``latex_output``
            holds a placeholder or error message and the confidence is 0.0,
            so a failure is never reported as a confident result.
        """
        if not os.path.exists(image_path):
            return {"error": f"Image not found: {image_path}", "latex_output": "", "weighted_confidence": 0.0}

        if str(image_path).lower().endswith('.inkml'):
            return self._process_inkml(image_path)

        try:
            with Image.open(image_path) as img:
                width, height = img.size
            if width == 0 or height == 0:
                return {"error": "Zero-size image", "latex_output": "", "weighted_confidence": 0.0}
        except Exception as e:
            return {"error": f"Invalid image: {e}", "latex_output": "", "weighted_confidence": 0.0}

        recognised = False
        if self.model_loaded and self.p2t:
            try:
                raw_latex = self._recognize_with_pix2text(image_path)
                recognised = bool(raw_latex.strip())
                if not recognised:
                    raw_latex = "No mathematical formula detected."
            except Exception as e:
                print(f"[OCR] Inference error: {e}")
                raw_latex = f"OCR Error: {str(e)}"
                recognised = False
        else:
            raw_latex = "No math detected (OCR model not loaded)."

        raw_latex = clean_latex_output(raw_latex)
        # Cleaning may remove everything that was recognised.
        recognised = recognised and bool(raw_latex.strip())

        return {
            "latex_output": raw_latex,
            "weighted_confidence": calculate_weighted_confidence(raw_latex) if recognised else 0.0,
            "backend": "pix2text-formula" if self.model_loaded else "simulation",
        }
|
|