# DOCVISION / src/visual_cues.py
# Author: ChinnaVemareddy23
# Last change: "Update src/visual_cues.py" (commit a25ea49, verified)
# import io
# import base64
# from typing import List, Dict, Tuple
# from PIL import Image
# from transformers import pipeline
# from src.config import LOGO_DETECTION_MODEL
# # --------------------------------------------------
# # MODEL INITIALIZATION (LOAD ONCE)
# # --------------------------------------------------
# # Object detection pipeline for logo / seal detection
# detector = pipeline(
# task="object-detection",
# model=LOGO_DETECTION_MODEL,
# device=-1 # CPU
# )
# # --------------------------------------------------
# # LOGO DETECTION
# # --------------------------------------------------
# def detect_logos_from_bytes(
# image_bytes: bytes,
# resize: Tuple[int, int] = (1024, 1024),
# max_logos: int = 3
# ) -> List[Dict[str, str | float]]:
# """
# Detect logos or visual emblems from raw image bytes.
# The function resizes the image for faster inference,
# detects logo regions, crops them, and returns the
# cropped logo images encoded in base64 along with
# confidence scores.
# Parameters
# ----------
# image_bytes : bytes
# Raw image data.
# resize : tuple[int, int], optional
# Maximum image size for inference (default: 1024x1024).
# max_logos : int, optional
# Maximum number of detected logos to return.
# Returns
# -------
# list[dict]
# List of detected logos with:
# - confidence: float
# - image_base64: str
# """
# # Load image from bytes
# image: Image.Image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# # Resize image for performance optimization
# image.thumbnail(resize)
# # Run object detection
# detections = detector(image)
# results: List[Dict[str, str | float]] = []
# # Process top detections only
# for det in detections[:max_logos]:
# box = det["box"]
# score: float = float(det["score"])
# xmin: int = int(box["xmin"])
# ymin: int = int(box["ymin"])
# xmax: int = int(box["xmax"])
# ymax: int = int(box["ymax"])
# # Crop detected logo region
# cropped = image.crop((xmin, ymin, xmax, ymax))
# # Convert cropped logo to base64
# buffer = io.BytesIO()
# cropped.save(buffer, format="PNG")
# results.append({
# "confidence": round(score, 3),
# "image_base64": base64.b64encode(buffer.getvalue()).decode()
# })
# return results
import io
import base64
from typing import List, Dict, Tuple
from PIL import Image
from transformers import pipeline
from src.config import LOGO_DETECTION_MODEL
# --------------------------------------------------
# MODEL INITIALIZATION (LOAD ONCE)
# --------------------------------------------------
# Created a single time at import so every detect_logos_from_bytes() call
# reuses the same loaded pipeline instead of reloading the model weights.
detector = pipeline(
    "object-detection",
    model=LOGO_DETECTION_MODEL,
    device=-1,  # -1 = CPU (HF Spaces safe)
)
# --------------------------------------------------
# LOGO DETECTION FUNCTION
# --------------------------------------------------
def _fit_within(image: "Image.Image", bounds: Tuple[int, int]) -> "Image.Image":
    """Downscale *image* uniformly so it fits inside *bounds* (width, height).

    The aspect ratio is preserved. Images that already fit are returned
    unchanged; they are never upscaled.
    """
    max_w, max_h = bounds
    scale = min(max_w / image.width, max_h / image.height, 1.0)
    if scale >= 1.0:
        return image
    # Scale both sides by the same factor (the old per-axis clamp distorted
    # non-square images); keep at least 1px per side.
    new_size = (
        max(1, round(image.width * scale)),
        max(1, round(image.height * scale)),
    )
    return image.resize(new_size)


def _png_base64(image: "Image.Image") -> str:
    """Encode *image* as PNG and return the bytes as a base64 text string."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def detect_logos_from_bytes(
    image_bytes: bytes,
    resize: Tuple[int, int] = (1024, 1024),
    max_logos: int = 4,
    threshold: float = 0.2
) -> List[Dict[str, str | float]]:
    """
    Detect logos or visual emblems from raw image bytes.

    The image is decoded and converted to RGB. If it is larger than
    ``resize``, it is downscaled uniformly (aspect ratio preserved) before
    inference. Detected regions are cropped from that (possibly downscaled)
    image and returned as base64-encoded PNGs, highest confidence first.

    Parameters
    ----------
    image_bytes : bytes
        Raw encoded image data (any format Pillow can open).
    resize : tuple[int, int], optional
        Maximum (width, height) used for inference (default: 1024x1024).
    max_logos : int, optional
        Maximum number of logos to return. Values <= 0 return an empty list.
    threshold : float, optional
        Minimum detection score, forwarded to the detector (default: 0.2).

    Returns
    -------
    list[dict]
        One dict per logo with:
        - confidence: float, rounded to 3 decimals
        - image_base64: str, the PNG crop encoded in base64

    Raises
    ------
    PIL.UnidentifiedImageError
        If ``image_bytes`` cannot be decoded as an image.
    """
    if max_logos <= 0:
        # Nothing can be returned, so skip decoding and inference entirely.
        return []

    # -------------------------------
    # Load and normalise the image
    # -------------------------------
    with Image.open(io.BytesIO(image_bytes)) as source:
        image = source.convert("RGB")
    image = _fit_within(image, resize)

    # -------------------------------
    # Object detection (EXPLICIT threshold)
    # -------------------------------
    detections = detector(image, threshold=threshold)
    if not detections:
        return []

    # Highest confidence first, so the max_logos cut keeps the best ones.
    detections = sorted(detections, key=lambda d: d["score"], reverse=True)

    results: List[Dict[str, str | float]] = []
    for det in detections:
        box = det["box"]
        # Clamp the box to the image so the crop never pads outside it.
        xmin = max(0, int(box["xmin"]))
        ymin = max(0, int(box["ymin"]))
        xmax = min(image.width, int(box["xmax"]))
        ymax = min(image.height, int(box["ymax"]))

        # Skip degenerate boxes WITHOUT spending one of the max_logos slots.
        if xmax <= xmin or ymax <= ymin:
            continue

        results.append({
            "confidence": round(float(det["score"]), 3),
            "image_base64": _png_base64(image.crop((xmin, ymin, xmax, ymax))),
        })
        if len(results) >= max_logos:
            break

    return results