# detect_and_segment.py
import torch
import supervision as sv
from typing import List, Tuple, Optional
from PIL import Image, ImageDraw, ImageColor

# ==== 1. One-time global model loading =====================================
from .utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from .utils.sam import load_sam_image_model, run_sam_inference
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load models once – they stay in memory for repeated calls
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
# quick annotators
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
color=COLOR_PALETTE,
color_lookup=sv.ColorLookup.INDEX,
text_position=sv.Position.CENTER_OF_MASS,
text_color=sv.Color.from_hex("#000000"),
border_radius=5,
)
MASK_ANNOTATOR = sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
# ==== 2. Inference function ===============================================
@torch.inference_mode()
@torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16)
def detect_and_segment(
image : Image.Image,
text_prompts : str | List[str],
return_image : bool = True,
) -> Tuple[sv.Detections, Optional[Image.Image]]:
"""
Run Florence-2 open-vocabulary detection + SAM2 mask refinement on a PIL image.
Parameters
----------
image : PIL.Image
Input image in RGB.
    text_prompts : str | List[str]
        A single comma-separated string (e.g. "dog, tail, leash") or a list of prompts.
return_image : bool
If True, also returns an annotated PIL image.
Returns
-------
detections : sv.Detections
Supervision object with xyxy, mask, class_id, etc.
annotated : PIL.Image | None
Annotated image (None if return_image=False)
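
    Examples
    --------
    A minimal usage sketch ("photo.jpg" is a placeholder path)::

        from PIL import Image

        img = Image.open("photo.jpg").convert("RGB")
        dets, vis = detect_and_segment(img, "dog, leash")
        print(len(dets))           # number of merged detections
        vis.save("annotated.png")  # masks + boxes + labels overlay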
"""
# Normalize prompt list
if isinstance(text_prompts, str):
prompts = [p.strip() for p in text_prompts.split(",") if p.strip()]
else:
prompts = [p.strip() for p in text_prompts]
if len(prompts) == 0:
raise ValueError("Empty prompt list given.")
# Collect detections from each prompt
det_list: list[sv.Detections] = []
for p in prompts:
_, result = run_florence_inference(
model = FLORENCE_MODEL,
processor = FLORENCE_PROC,
device = DEVICE,
image = image,
task = FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
text = p,
)
det = sv.Detections.from_lmm(
lmm = sv.LMM.FLORENCE_2,
result = result,
resolution_wh = image.size,
)
det = run_sam_inference(SAM_IMAGE_MODEL, image, det) # SAM2 refinement
det_list.append(det)
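    # merge per-prompt detections into a single sv.Detections object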
detections = sv.Detections.merge(det_list)
annotated_img = None
if return_image:
annotated_img = image.copy()
annotated_img = MASK_ANNOTATOR.annotate(annotated_img, detections)
annotated_img = BOX_ANNOTATOR.annotate(annotated_img, detections)
annotated_img = LABEL_ANNOTATOR.annotate(annotated_img, detections)
return detections, annotated_img
def fill_detected_bboxes(
    image: Image.Image,
    text: str,
    inflate_pct: float = 0.10,
    fill_color: str | tuple[int, int, int] = "#00FF00",
) -> Tuple[Image.Image, sv.Detections]:
"""
Detect objects matching `text`, inflate each bounding-box by `inflate_pct`,
fill the area with `fill_color`, and return the resulting image.
Parameters
----------
image : PIL.Image
Input image (RGB).
text : str
Comma-separated prompt(s) for open-vocabulary detection.
    inflate_pct : float, default 0.10
        Extra margin per side, as a fraction of the box size (0.10 adds 10 %
        of the width on each side and 10 % of the height above and below,
        i.e. +20 % total width & height).
fill_color : str | tuple, default "#00FF00"
Solid color used to fill each inflated bbox (hex or RGB tuple).
Returns
-------
filled_img : PIL.Image
Image with each detected (inflated) box filled.
detections : sv.Detections
Original detection object from `detect_and_segment`.
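
    Examples
    --------
    A minimal sketch ("photo.jpg" is a placeholder path)::

        from PIL import Image

        img = Image.open("photo.jpg").convert("RGB")
        filled, dets = fill_detected_bboxes(img, "car", inflate_pct=0.05)
        filled.save("car_filled.png")  # each inflated car bbox painted green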
"""
    # run the Florence-2 + SAM2 pipeline defined above
detections, _ = detect_and_segment(image, text)
w, h = image.size
filled_img = image.copy()
draw = ImageDraw.Draw(filled_img)
fill_rgb = ImageColor.getrgb(fill_color) if isinstance(fill_color, str) else fill_color
for box in detections.xyxy:
        # xyxy is a numpy array → cast to float for math
x1, y1, x2, y2 = box.astype(float)
dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
x1_i = max(0, x1 - dw)
y1_i = max(0, y1 - dh)
x2_i = min(w, x2 + dw)
y2_i = min(h, y2 + dh)
draw.rectangle([x1_i, y1_i, x2_i, y2_i], fill=fill_rgb)
return filled_img, detections
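

if __name__ == "__main__":
    # Minimal smoke test -- a sketch only. "example.jpg" is a placeholder
    # path; pass your own image as the first CLI argument.
    import sys

    src = sys.argv[1] if len(sys.argv) > 1 else "example.jpg"
    img = Image.open(src).convert("RGB")

    dets, vis = detect_and_segment(img, "person, dog")
    print(f"found {len(dets)} detection(s)")
    if vis is not None:
        vis.save("annotated.png")

    filled, _ = fill_detected_bboxes(img, "person", inflate_pct=0.10)
    filled.save("filled.png")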