# main.py — OCR Scan Vision API (FastAPI + PaddleX OCR service)
import os
import re
import tempfile
import traceback
import cv2
import numpy as np
import uvicorn
from io import BytesIO
from typing import List, Dict, Any
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
# -------------------- Writable env for PaddleX --------------------
# Pin BLAS/OpenMP backends to one thread (small container, avoids
# oversubscription) and route cache/resource paths to writable locations.
for _thread_env in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
):
    os.environ[_thread_env] = "1"
# Keep a pre-set, non-empty value; otherwise fall back to a writable default.
for _path_env, _fallback in (
    ("HOME", "/home/user"),
    ("XDG_CACHE_HOME", "/home/user/.cache"),
    ("PADDLEX_HOME", "/home/user/.paddlex"),
    ("TMPDIR", "/tmp"),
):
    os.environ[_path_env] = os.environ.get(_path_env, _fallback) or _fallback
# Optional PDF support: pdf2image (and its poppler backend) may be absent.
try:
    from pdf2image import convert_from_bytes
except ImportError:
    PDF_AVAILABLE = False
else:
    PDF_AVAILABLE = True
# -------------------- Global OCR Models --------------------
# Lazily-created model handles; populated on the first get_models() call.
paddle_detector = None
paddle_recognizer = None
# Cached failure detail (dict) from a previous load attempt, or None.
model_load_error = None
# Candidate model names, tried in order until one loads successfully.
DETECTOR_CANDIDATES = ["PP-OCRv5_mobile_det", "PP-OCRv5_server_det"]
RECOGNIZER_CANDIDATES = ["arabic_PP-OCRv5_mobile_rec", "PP-OCRv5_mobile_rec"]
app = FastAPI(title="OCR Scan Vision API", version="1.2.0")
# Wide-open CORS policy so browser clients on any origin can call the API.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# -------------------- Arabic Text Cleaning --------------------
def smart_clean_arabic_text(text: str) -> str:
    """Normalize OCR output containing Arabic (and mixed Latin) text.

    The pipeline is order-sensitive:
      1. Separator punctuation (: - _ /) becomes a space.
      2. Arabic letters that OCR split with a stray space are re-joined.
      3. Quotation marks are dropped, sentence punctuation re-spaced,
         tashkeel (diacritics) stripped, and any character that is not
         Arabic, Latin, a digit, or whitespace removed.
      4. Whitespace runs are collapsed and the result stripped.

    Returns "" for empty/falsy input.
    """
    if not text:
        return ""
    # Separator punctuation becomes a space so the surrounding tokens stay apart.
    text = re.sub(r"[:\-_\/]", " ", text)
    # Re-attach alef variants to a following lam ("ا ل..." -> "ال...").
    text = re.sub(r"([اأإآ])\s+([ل])", r"\1\2", text)
    # Non-connecting letters (ا أ إ آ د ذ ر ز و) followed by a spurious space.
    text = re.sub(r"([اأإآدذرزو])\s+([\u0600-\u06FF])", r"\1\2", text)
    # Drop quotation marks entirely.
    text = re.sub(r"[«»“”‘’]", "", text)
    # Normalize spacing around sentence punctuation and parentheses.
    text = re.sub(r"\s*([.,;?!])\s*", r"\1 ", text)
    text = re.sub(r"\s*([()])\s*", r" \1 ", text)
    # NOTE: a former step collapsed spaces around "-", but step 1 already
    # replaced every hyphen with a space, so that rule could never match
    # and has been removed (dead code).
    # Strip Arabic diacritics (tashkeel).
    text = re.sub(r"[\u064B-\u065F]", "", text)
    # Keep only Arabic, Latin letters, digits, and whitespace.
    text = re.sub(r"[^\u0600-\u06FFA-Za-z0-9\s]", "", text)
    # Collapse whitespace runs.
    text = re.sub(r"\s+", " ", text)
    return text.strip()
# -------------------- Helpers --------------------
def get_versions_info() -> Dict[str, str]:
    """Collect library version strings for diagnostics endpoints and errors."""

    def _version_of(module_name: str) -> str:
        # Import on demand; report the failure text instead of raising.
        try:
            module = __import__(module_name)
            return getattr(module, "__version__", "unknown")
        except Exception as e:
            return f"not available: {e}"

    info = {"paddle": _version_of("paddle")}
    # Deliberately never "import paddlex" here: merely importing it used to
    # open font/resource paths and break on permission-restricted hosts.
    info["paddlex"] = "lazy import only"
    try:
        info["numpy"] = np.__version__
    except Exception as e:
        info["numpy"] = f"not available: {e}"
    info["opencv"] = _version_of("cv2")
    return info
def validate_image_array(img: np.ndarray) -> None:
    """Raise ValueError unless *img* is a non-empty 3-channel (BGR) ndarray."""
    if img is None:
        raise ValueError("Image is None")
    if not isinstance(img, np.ndarray):
        raise ValueError("Image is not a numpy array")
    if img.size == 0:
        raise ValueError("Image array is empty")
    is_three_channel = img.ndim == 3 and img.shape[-1] == 3
    if not is_three_channel:
        raise ValueError(f"Expected BGR image with 3 channels, got shape={img.shape}")
def ensure_runtime_dirs():
    """Create every writable runtime directory PaddleX may need to touch."""
    env_defaults = {
        "HOME": "/home/user",
        "XDG_CACHE_HOME": "/home/user/.cache",
        "PADDLEX_HOME": "/home/user/.paddlex",
    }
    targets = [os.environ.get(name, default) for name, default in env_defaults.items()]
    targets.append("/tmp")
    for target in targets:
        if target:
            os.makedirs(target, exist_ok=True)
def load_first_available_model(create_model_fn, role: str, candidates: List[str]):
    """Try each candidate model name in order.

    Returns (model, model_name) for the first candidate that loads, or
    raises RuntimeError describing every failure.
    """
    failures = []
    for candidate in candidates:
        print(f"Loading {role} model: {candidate}")
        try:
            loaded = create_model_fn(candidate)
        except Exception as e:
            traceback.print_exc()
            failures.append({
                "model": candidate,
                "error": str(e),
            })
            print(f"Failed to load {role} model {candidate}: {e}")
            continue
        print(f"{role.capitalize()} model loaded successfully: {candidate}")
        return loaded, candidate
    raise RuntimeError(f"All {role} model candidates failed: {failures}")
def save_temp_image(img: np.ndarray, suffix: str = ".png") -> str:
    """Write *img* to a fresh temporary file and return its path.

    The caller owns the file and is responsible for deleting it.
    Raises RuntimeError when OpenCV cannot encode/write the image.
    """
    validate_image_array(img)
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    if not cv2.imwrite(tmp_path, img):
        raise RuntimeError(f"Failed to write temporary image: {tmp_path}")
    return tmp_path
def get_models():
    """Return the (detector, recognizer) pair, loading both lazily on first use.

    Loading is attempted at most once: a failure is cached in
    ``model_load_error`` and re-raised as HTTP 500 on every later call, so
    the app never retries an expensive, known-bad model load per request.
    """
    global paddle_detector, paddle_recognizer, model_load_error
    # A previous load attempt failed; surface the cached error immediately.
    if model_load_error is not None:
        raise HTTPException(status_code=500, detail=model_load_error)
    if paddle_detector is None or paddle_recognizer is None:
        try:
            ensure_runtime_dirs()
            print("Loading PaddleX OCR models...")
            # Lazy imports: importing paddlex at module import time can fail
            # on read-only hosts (font/resource path permissions).
            import paddle
            from paddlex import create_model
            try:
                paddle.set_num_threads(1)
            except Exception as thread_error:
                print(f"Warning: failed to set Paddle threads to 1: {thread_error}")
            paddle_detector, detector_name = load_first_available_model(
                create_model, "detector", DETECTOR_CANDIDATES
            )
            paddle_recognizer, recognizer_name = load_first_available_model(
                create_model, "recognizer", RECOGNIZER_CANDIDATES
            )
            print(
                "Models loaded successfully.",
                {
                    "detector": detector_name,
                    "recognizer": recognizer_name,
                },
            )
        except Exception as e:
            traceback.print_exc()
            # Cache the failure so later calls fail fast with the same detail.
            model_load_error = {
                "message": "OCR models failed to load",
                "error": str(e),
                "hint": "Likely PaddleX resource/font permission, thread configuration, or runtime compatibility issue.",
                "versions": get_versions_info(),
            }
            raise HTTPException(status_code=500, detail=model_load_error)
    return paddle_detector, paddle_recognizer
def safe_result_get(result: Any, key: str, default: Any = None) -> Any:
    """Look up *key* on *result* as a dict key, attribute, or item.

    The first successful lookup wins; any miss, error, or None value yields
    *default* (except a dict key explicitly mapped to None, which dict.get
    returns as-is).
    """
    if result is None:
        return default
    if isinstance(result, dict):
        return result.get(key, default)
    if hasattr(result, key):
        try:
            found = getattr(result, key)
            return default if found is None else found
        except Exception:
            pass  # attribute access failed; fall through to item access
    try:
        found = result[key]
    except Exception:
        return default
    return default if found is None else found
def normalize_items(value: Any) -> List[Any]:
    """Coerce any detector/recognizer field into a plain Python list."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    if isinstance(value, np.ndarray):
        # tolist() converts nested arrays into nested Python lists.
        return value.tolist()
    # Scalar / single object: wrap it.
    return [value]
def describe_result_shape(result: Any) -> Dict[str, Any]:
    """Summarize a result object for debug logs: its type plus up to ten
    dict keys (for dicts) or public attribute names (for anything else)."""
    summary: Dict[str, Any] = {"type": type(result).__name__}
    if isinstance(result, dict):
        summary["keys"] = list(result.keys())[:10]
    else:
        try:
            public_names = [name for name in dir(result) if not name.startswith("_")]
        except Exception:
            public_names = []
        summary["attrs"] = public_names[:10]
    return summary
def extract_boxes_from_detection_result(result: Any) -> List[Any]:
    """Pull polygon/box lists out of a PaddleX detection result.

    Searches known field names directly, recurses into lists/tuples of
    results, and finally looks inside common wrapper fields.
    """
    if result is None:
        return []
    # A list/tuple of results: concatenate boxes from every element.
    if isinstance(result, (list, tuple)):
        merged: List[Any] = []
        for element in result:
            merged.extend(extract_boxes_from_detection_result(element))
        return merged
    # Direct box fields, in order of preference.
    for field in ("dt_polys", "polys", "boxes", "det_polys", "dt_boxes"):
        candidate = normalize_items(safe_result_get(result, field))
        if candidate:
            return candidate
    # Wrapper fields that may hold the real payload one level down.
    for field in ("res", "result", "data", "output"):
        inner = safe_result_get(result, field)
        if inner is None or inner is result:
            continue
        inner_boxes = extract_boxes_from_detection_result(inner)
        if inner_boxes:
            return inner_boxes
    return []
def extract_recognition_payload(result: Any) -> Dict[str, Any]:
    """Extract ``{"raw_text": str, "score": float}`` from a recognition result.

    Handles several PaddleX result shapes: None, lists/tuples of results
    (the highest-scoring item wins, and any text beats no text), wrapper
    objects holding the payload under res/result/data/output, and flat
    objects/dicts exposing text and score under a few known field names.
    """
    if result is None:
        return {
            "raw_text": "",
            "score": 0.0,
        }
    if isinstance(result, (list, tuple)):
        # Pick the best item: higher score wins; any text beats no text.
        best_payload = {
            "raw_text": "",
            "score": 0.0,
        }
        for item in result:
            payload = extract_recognition_payload(item)
            if payload["score"] > best_payload["score"] or (
                payload["raw_text"] and not best_payload["raw_text"]
            ):
                best_payload = payload
        return best_payload
    # Unwrap container fields that may hold the real payload one level down.
    for key in ["res", "result", "data", "output"]:
        nested = safe_result_get(result, key)
        if nested is not None and nested is not result:
            nested_payload = extract_recognition_payload(nested)
            if nested_payload["raw_text"]:
                return nested_payload
    # Probe known text fields in order of preference; stop at the first hit.
    texts = None
    for key in ["rec_text", "text", "label", "transcription", "ocr_text", "rec_texts"]:
        texts = safe_result_get(result, key)
        if texts not in [None, ""]:
            break
    # Probe known score fields in order of preference.
    scores = None
    for key in ["rec_score", "score", "confidence", "probability", "rec_scores"]:
        scores = safe_result_get(result, key)
        if scores is not None:
            break
    normalized_texts = normalize_items(texts)
    normalized_scores = normalize_items(scores)
    raw_text = ""
    if normalized_texts:
        raw_text = " ".join(str(item).strip() for item in normalized_texts if str(item).strip())
    elif isinstance(texts, str):
        raw_text = texts.strip()
    score = 0.0
    if normalized_scores:
        try:
            # Multiple scores (one per text line): keep the maximum.
            score = max(float(item) for item in normalized_scores if item is not None)
        except Exception:
            score = 0.0
    else:
        try:
            if scores is not None:
                score = float(scores)
        except Exception:
            score = 0.0
    return {
        "raw_text": raw_text,
        "score": score,
    }
def run_detector(detector, img: np.ndarray) -> List[Dict[str, Any]]:
    """Run text detection on *img*.

    Tries file-path input first (more robust across PaddleX versions) and
    falls back to direct ndarray input; the temporary file is always removed.
    Raises RuntimeError when both attempts fail.
    """
    temp_path = save_temp_image(img, suffix=".png")
    try:
        try:
            # Materialize the (possibly lazy) result iterator before the
            # temp file is removed in the finally block.
            return list(detector.predict(temp_path))
        except Exception as path_error:
            print(f"Detector path inference failed, retrying with ndarray input: {path_error}")
            traceback.print_exc()
            try:
                return list(detector.predict(img))
            except Exception as direct_error:
                traceback.print_exc()
                raise RuntimeError(
                    f"Detector inference failed. path_error={path_error}; direct_error={direct_error}"
                )
    finally:
        try:
            os.remove(temp_path)
        except OSError:
            pass
def run_recognizer(recognizer, roi: np.ndarray) -> Dict[str, Any]:
    """Run text recognition on a cropped region.

    Tries direct ndarray input first and falls back to a temporary file;
    returns the first prediction (or {} when the model yields nothing).
    Raises RuntimeError when both attempts fail.
    """
    try:
        return next(recognizer.predict(roi), {})
    except Exception as direct_error:
        print(f"Recognizer ndarray inference failed, retrying with temp file: {direct_error}")
        traceback.print_exc()
        temp_path = save_temp_image(roi, suffix=".png")
        try:
            return next(recognizer.predict(temp_path), {})
        except Exception as path_error:
            traceback.print_exc()
            raise RuntimeError(
                f"Recognizer inference failed. direct_error={direct_error}; path_error={path_error}"
            )
        finally:
            try:
                os.remove(temp_path)
            except OSError:
                pass
def process_image(img: np.ndarray, detector, recognizer, min_conf: float) -> List[Dict]:
    """Detect text boxes in *img*, recognize each crop, and return accepted
    results as dicts with box_id, text, confidence, and bbox [x1, y1, x2, y2].

    Results scoring below *min_conf* or with empty cleaned text are dropped;
    the rest are sorted top-to-bottom, then right-to-left (Arabic reading
    order). Raises RuntimeError when detection or its post-processing fails;
    per-box recognition failures are logged and skipped.
    """
    validate_image_array(img)
    h_img, w_img = img.shape[:2]
    # Detection
    try:
        results = run_detector(detector, img)
    except Exception as e:
        raise RuntimeError(f"Detector inference failed: {str(e)}")
    print(f"Detector returned {len(results)} result item(s)")
    if results:
        print("First detector result shape:", describe_result_shape(results[0]))
    all_rois = []
    all_bboxes = []
    try:
        for result in results:
            boxes = extract_boxes_from_detection_result(result)
            for box in boxes:
                # Convert each polygon to an axis-aligned rect clamped to the image.
                pts = np.array(box, dtype=np.int32)
                if pts.size == 0:
                    continue
                x, y, w, h = cv2.boundingRect(pts)
                x1 = max(int(x), 0)
                y1 = max(int(y), 0)
                x2 = min(int(x + w), w_img)
                y2 = min(int(y + h), h_img)
                if x2 > x1 and y2 > y1:
                    roi = img[y1:y2, x1:x2]
                    if roi is not None and roi.size > 0:
                        all_rois.append(roi)
                        all_bboxes.append([x1, y1, x2, y2])
    except Exception as e:
        raise RuntimeError(f"Detection post-processing failed: {str(e)}")
    if not all_rois:
        print("Detector produced no usable ROIs. Falling back to full-image recognition.")
        all_rois.append(img)
        all_bboxes.append([0, 0, w_img, h_img])
    ocr_results = []
    for i, roi in enumerate(all_rois):
        try:
            rec = run_recognizer(recognizer, roi)
            payload = extract_recognition_payload(rec)
            raw_text = payload["raw_text"]
            score = float(payload["score"])
            text = smart_clean_arabic_text(raw_text)
        except Exception as e:
            # A recognition failure skips this box rather than aborting the page.
            raw_text = ""
            text = ""
            score = 0.0
            print(f"Recognition failed for ROI #{i + 1}: {e}")
        if i == 0:
            # 'rec' is undefined when the first recognition attempt raised,
            # hence the locals() guard.
            print(
                f"Recognition sample #{i + 1}: raw_text_len={len(raw_text)}, score={score}, "
                f"result_shape={describe_result_shape(rec) if 'rec' in locals() else 'n/a'}"
            )
        if score >= min_conf and text:
            ocr_results.append({
                "box_id": i + 1,
                "text": text,
                "confidence": round(score, 4),
                "bbox": all_bboxes[i]
            })
    # Top to bottom, then right to left (natural Arabic reading order).
    ocr_results.sort(
        key=lambda x: (
            x["bbox"][1],
            -x["bbox"][0]
        )
    )
    print(
        f"OCR produced {len(ocr_results)} accepted item(s) from {len(all_rois)} ROI(s) "
        f"with min_conf={min_conf}"
    )
    return ocr_results
def read_uploaded_image(contents: bytes) -> np.ndarray:
    """Decode raw upload bytes into a BGR OpenCV image.

    Raises HTTP 400 when the payload is not a decodable 3-channel image.
    """
    try:
        rgb_image = Image.open(BytesIO(contents)).convert("RGB")
        bgr_image = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
        validate_image_array(bgr_image)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")
    return bgr_image
# -------------------- Events --------------------
@app.on_event("startup")
def startup_event():
    """Prepare writable dirs, log diagnostics, and warm up the OCR models."""
    ensure_runtime_dirs()
    print("Application startup complete.")
    print("Versions:", get_versions_info())
    thread_env = {
        name: os.environ.get(name)
        for name in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")
    }
    print("Thread env:", thread_env)
    # Warmup is best-effort: a failure must not block startup. get_models()
    # caches the error and surfaces it per-request instead.
    try:
        get_models()
        print("Startup warmup completed successfully.")
    except Exception as e:
        print(f"Startup warmup failed: {e}")
        traceback.print_exc()
# -------------------- Endpoints --------------------
@app.get("/")
def root():
    """Service banner: name, status, PDF capability, and library versions."""
    banner = {"name": "OCR Scan Vision API", "status": "ok"}
    banner["pdf_support"] = PDF_AVAILABLE
    banner["versions"] = get_versions_info()
    return banner
@app.get("/health")
def health():
    """Liveness probe, including library version diagnostics."""
    return {"status": "healthy", "versions": get_versions_info()}
@app.get("/versions")
def versions():
    """Expose library version info as a standalone endpoint."""
    return get_versions_info()
@app.api_route("/load-models", methods=["GET", "POST"])
def load_models():
    """Eagerly load the OCR models (instead of waiting for the first OCR
    request) and report which handles are populated."""
    detector, recognizer = get_models()
    return {
        "message": "Models loaded successfully",
        "detector_loaded": detector is not None,
        "recognizer_loaded": recognizer is not None,
        "versions": get_versions_info(),
    }
@app.post("/ocr")
async def ocr_image(
    file: UploadFile = File(...),
    min_conf: float = Query(default=0.0, ge=0.0, le=1.0),
):
    """Run OCR on a single uploaded image and return per-box results plus
    the joined text; boxes below *min_conf* are filtered by the pipeline."""
    contents = await file.read()
    img = read_uploaded_image(contents)
    detector, recognizer = get_models()
    try:
        items = process_image(img, detector, recognizer, min_conf)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail={
                "message": "OCR inference failed",
                "error": str(e),
                "versions": get_versions_info(),
                "filename": file.filename,
                "min_conf": min_conf,
            }
        )
    return {
        "items": items,
        "text": "\n".join(item["text"] for item in items),
        "total_boxes": len(items),
        "filename": file.filename,
    }
@app.post("/ocr-pdf")
async def ocr_pdf(
    file: UploadFile = File(...),
    dpi: int = Query(default=300, ge=72, le=600),
    min_conf: float = Query(default=0.0, ge=0.0, le=1.0),
):
    """Rasterize an uploaded PDF page-by-page and run OCR on every page.

    Returns all items (each tagged with a 1-based "page" number), combined
    text with per-page headers, and summary counts. Responds 400 for an
    unreadable PDF, 500 when PDF support is missing or OCR fails on a page.
    """
    if not PDF_AVAILABLE:
        raise HTTPException(
            status_code=500,
            detail="PDF support not available. pdf2image or poppler may be missing."
        )
    try:
        contents = await file.read()
        pages = convert_from_bytes(contents, dpi=dpi)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid PDF file: {str(e)}")
    detector, recognizer = get_models()
    all_results = []
    all_text = []
    for page_num, pil_img in enumerate(pages, start=1):
        try:
            # PIL (RGB) -> OpenCV (BGR) for the OCR pipeline.
            img = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2BGR)
            validate_image_array(img)
            page_results = process_image(img, detector, recognizer, min_conf)
            for item in page_results:
                item["page"] = page_num
            all_results.extend(page_results)
            page_text = "\n".join([r["text"] for r in page_results])
            if page_text:
                all_text.append(f"--- Page {page_num} ---\n{page_text}")
        except Exception as e:
            traceback.print_exc()
            # Fail the whole request on the first bad page, with context.
            raise HTTPException(
                status_code=500,
                detail={
                    "message": f"OCR failed on PDF page {page_num}",
                    "error": str(e),
                    "versions": get_versions_info(),
                    "filename": file.filename,
                    "dpi": dpi,
                    "min_conf": min_conf,
                }
            )
    return {
        "pages": len(pages),
        "items": all_results,
        "text": "\n\n".join(all_text),
        "total_boxes": len(all_results),
        "filename": file.filename,
    }
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; default to 7860 for local runs.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)