#!/usr/bin/env python
# furniture_bbox_to_files.py ────────────────────────────────────────
# Florence-2 + SAM-2 batch processor with retries *and* file-based images
# --------------------------------------------------------------------
import os, json, random, time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
import torch, supervision as sv
from PIL import Image, ImageDraw, ImageColor, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset, Image as HFImage, disable_progress_bar
# ───── global models ────────────────────────────────────────────────
from utils.florence import (
load_florence_model, run_florence_inference,
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from utils.sam import load_sam_image_model, run_sam_inference
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
# annotators
_PALETTE = sv.ColorPalette.from_hex(
['#FF1493','#00BFFF','#FF6347','#FFD700','#32CD32','#8A2BE2'])
BOX_ANN = sv.BoxAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
MASK_ANN = sv.MaskAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LBL_ANN = sv.LabelAnnotator(
color=_PALETTE, color_lookup=sv.ColorLookup.INDEX,
text_position=sv.Position.CENTER_OF_MASS,
text_color=sv.Color.from_hex("#000"), border_radius=5)
# ───── config ───────────────────────────────────────────────────────
os.environ["TOKENIZERS_PARALLELISM"] = "false"
disable_progress_bar()
DATASET_NAME = "fotographerai/furniture_captioned_segment_prompt"
SPLIT = "train"
IMAGE_COL = "img2"
PROMPT_COL = "segmenting_prompt"
INFLATE_RANGE = (0.01, 0.05)  # random bbox inflation, as a fraction of box width/height
FILL_COLOR = "#00FF00"        # chroma-key green painted over each box
TARGET_SIDE = 1500            # side of the square output canvas, in pixels
QA_DIR = Path("bbox_review_recaptioned")
GREEN_DIR = QA_DIR / "green"; GREEN_DIR.mkdir(parents=True, exist_ok=True)
ANNO_DIR = QA_DIR / "anno"; ANNO_DIR.mkdir(parents=True, exist_ok=True)
JSON_DIR = QA_DIR / "json"; JSON_DIR.mkdir(parents=True, exist_ok=True)
MAX_WORKERS = 100   # threads share one GPU, so this mainly overlaps disk I/O
MAX_RETRIES = 5     # attempts per row before it is logged as a permanent failure
RETRY_SLEEP = 0.3   # seconds to wait between attempts
FAILED_LOG = QA_DIR / "failed_rows.jsonl"
PROMPT_MAP: dict[str,str] = {} # optional overrides
# ───── helpers ──────────────────────────────────────────────────────
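# Letterbox `img` onto a side × side canvas, padding with the top-left pixel
# colour so the border blends with the original background.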
def make_square(img: Image.Image, side: int = TARGET_SIDE) -> Image.Image:
img = ImageOps.contain(img, (side, side))
pad_w, pad_h = side - img.width, side - img.height
return ImageOps.expand(img, border=(pad_w//2, pad_h//2,
pad_w - pad_w//2, pad_h - pad_h//2),
fill=img.getpixel((0,0)))
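# Save the image to disk once and return the {"path", "bytes"} mapping that
# datasets.Image() accepts when the column is cast later on.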
def img_to_file(img: Image.Image, fname: str, folder: Path) -> dict:
path = folder / f"{fname}.png"
if not path.exists():
img.save(path)
return {"path": str(path), "bytes": None}
# ───── core functions ───────────────────────────────────────────────
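# Run Florence-2 open-vocabulary detection once per comma-separated prompt,
# refine each detection set with SAM-2 masks, and merge the results.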
@torch.inference_mode()
@torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16)  # follow DEVICE so CPU-only runs don't crash
def detect_and_segment(img: Image.Image, prompts: str|List[str]) -> sv.Detections:
if isinstance(prompts, str):
prompts = [p.strip() for p in prompts.split(",") if p.strip()]
all_dets = []
for p in prompts:
_, res = run_florence_inference(
model=FLORENCE_MODEL, processor=FLORENCE_PROC, device=DEVICE,
image=img, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text=p)
        d = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, res, resolution_wh=img.size)
all_dets.append(run_sam_inference(SAM_IMAGE_MODEL, img, d))
return sv.Detections.merge(all_dets)
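# Paint every detected box, inflated by `inflate_pct` per side and clamped to
# the image bounds, with the solid FILL_COLOR.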
def fill_detected_bboxes(img: Image.Image, prompt: str,
inflate_pct: float) -> tuple[Image.Image, sv.Detections]:
dets = detect_and_segment(img, prompt)
filled = img.copy()
draw = ImageDraw.Draw(filled)
rgb = ImageColor.getrgb(FILL_COLOR)
w,h = img.size
for box in dets.xyxy:
x1,y1,x2,y2 = box.astype(float)
dw,dh = (x2-x1)*inflate_pct, (y2-y1)*inflate_pct
draw.rectangle([max(0,x1-dw), max(0,y1-dh),
min(w,x2+dw), min(h,y2+dh)], fill=rgb)
return filled, dets
# ───── threaded worker ──────────────────────────────────────────────
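# Process one row: resolve the prompt (PROMPT_MAP override if present, else
# the first comma-separated segment), square the image, and retry the
# detect-and-fill step up to MAX_RETRIES times before giving up.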
def process_row(idx: int, sample):
prompt = PROMPT_MAP.get(sample[PROMPT_COL],
sample[PROMPT_COL].split(",",1)[0].strip())
img_sq = make_square(sample[IMAGE_COL].convert("RGB"))
for attempt in range(1, MAX_RETRIES+1):
try:
filled, dets = fill_detected_bboxes(
img_sq, prompt, inflate_pct=random.uniform(*INFLATE_RANGE))
if len(dets.xyxy) == 0:
raise ValueError("no detections")
sid = f"{idx:06d}"
json_p = JSON_DIR / f"{sid}_bbox.json"
json_p.write_text(json.dumps({"xyxy": dets.xyxy.tolist()}))
anno = img_sq.copy()
            for ann in (MASK_ANN, BOX_ANN, LBL_ANN):
anno = ann.annotate(anno, dets)
return ("ok",
img_to_file(filled, sid, GREEN_DIR),
img_to_file(anno, sid, ANNO_DIR),
json_p.read_text())
except Exception as e:
if attempt < MAX_RETRIES:
time.sleep(RETRY_SLEEP)
else:
return ("fail", str(e))
# ───── run batch ────────────────────────────────────────────────────
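# Load the split eagerly, fan the rows out to a thread pool, and collect
# results into index-aligned columns so the original row order is preserved.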
ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N = len(ds)
print("Rows:", N)
filled_col, anno_col, json_col = [None]*N, [None]*N, [None]*N
fails = 0
with ThreadPoolExecutor(MAX_WORKERS) as pool:
fut2idx = {pool.submit(process_row, i, ds[i]): i for i in range(N)}
for fut in tqdm(as_completed(fut2idx), total=N, desc="Florence+SAM"):
idx = fut2idx[fut]
status, *data = fut.result()
if status == "ok":
filled_col[idx], anno_col[idx], json_col[idx] = data
else:
fails += 1
            with FAILED_LOG.open("a") as fh:  # append so earlier failures aren't overwritten
                fh.write(json.dumps({"idx": idx, "reason": data[0]}) + "\n")
print(f"❌ permanently failed rows: {fails}")
keep = [i for i,x in enumerate(filled_col) if x]
new_ds = ds.select(keep)
new_ds = new_ds.add_column("bbox_filled", [filled_col[i] for i in keep])
new_ds = new_ds.add_column("annotated", [anno_col[i] for i in keep])
new_ds = new_ds.add_column("bbox_json", [json_col[i] for i in keep])
new_ds = new_ds.cast_column("bbox_filled", HFImage())
new_ds = new_ds.cast_column("annotated", HFImage())
print(f"βœ… successes: {len(new_ds)} / {N}")
print("Columns:", new_ds.column_names)
print("QA artefacts β†’", QA_DIR.resolve())
# optional push
new_ds.push_to_hub("fotographerai/surround_furniture_bboxfilled",
private=True, max_shard_size="500MB")
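# Quick sanity check of the pushed dataset (a sketch; assumes the default
# "train" split and that you are authenticated for the private repo):
#   from datasets import load_dataset
#   check = load_dataset("fotographerai/surround_furniture_bboxfilled", split="train")
#   check[0]["bbox_filled"]  # should be a green-filled PIL image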