Spaces:
Running
Running
#!/usr/bin/env python
# furniture_bbox_to_files.py ─────────────────────────────────────────
# Florence-2 + SAM-2 batch processor with retries *and* file-based images
# --------------------------------------------------------------------
import os, json, random, time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List

import torch, supervision as sv
from PIL import Image, ImageDraw, ImageColor, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset, Image as HFImage, disable_progress_bar

# ───── global models ──────────────────────────────────────────────
from utils.florence import (
    load_florence_model, run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from utils.sam import load_sam_image_model, run_sam_inference
# Models are loaded once at import time and shared by every worker thread;
# CUDA is used when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# ───── annotators ─────
# All three annotators use the same palette and INDEX colour lookup, so the
# mask, box and label drawn for a given detection index share one colour.
_PALETTE = sv.ColorPalette.from_hex([
    "#FF1493", "#00BFFF", "#FF6347",
    "#FFD700", "#32CD32", "#8A2BE2",
])
BOX_ANN = sv.BoxAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
MASK_ANN = sv.MaskAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LBL_ANN = sv.LabelAnnotator(
    color=_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000"),
    border_radius=5,
)
# ───── config ───────────────────────────────────────────────────────
# Turn off HF tokenizers' own parallelism and the datasets progress bars;
# the batch loop reports progress through its own tqdm bar.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
disable_progress_bar()

# Source dataset
DATASET_NAME = "fotographerai/furniture_captioned_segment_prompt"
SPLIT = "train"
IMAGE_COL = "img2"                 # column holding the input image
PROMPT_COL = "segmenting_prompt"   # comma-separated prompts; the first one is
                                   # used unless PROMPT_MAP overrides it

# Box filling
INFLATE_RANGE = (0.01, 0.05)  # per-attempt random fraction by which each box
                              # grows on every side before it is filled
FILL_COLOR = "#00FF00"
TARGET_SIDE = 1500            # inputs are fitted and padded to this square side

# Output layout (directories are created at import time)
QA_DIR = Path("bbox_review_recaptioned")
GREEN_DIR = QA_DIR / "green"   # box-filled images
ANNO_DIR = QA_DIR / "anno"     # mask/box/label overlays
JSON_DIR = QA_DIR / "json"     # detected xyxy boxes
for _out_dir in (GREEN_DIR, ANNO_DIR, JSON_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir
FAILED_LOG = QA_DIR / "failed_rows.jsonl"

# Concurrency and retries
MAX_WORKERS = 100
MAX_RETRIES = 5
RETRY_SLEEP = 0.3   # seconds to wait between attempts

# Optional overrides: raw PROMPT_COL value -> prompt actually sent to Florence
PROMPT_MAP: dict[str, str] = {}
# ───── helpers ──────────────────────────────────────────────────────
def make_square(img: Image.Image, side: int = TARGET_SIDE) -> Image.Image:
    """Return *img* resized to fit a ``side`` x ``side`` square and padded to fill it.

    ``ImageOps.contain`` keeps the aspect ratio. The leftover space is split
    evenly between opposite edges; any odd pixel goes to the right/bottom.
    Padding uses the colour of the resized image's top-left pixel.
    """
    fitted = ImageOps.contain(img, (side, side))
    extra_w = side - fitted.width
    extra_h = side - fitted.height
    left = extra_w // 2
    top = extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(fitted, border=border, fill=fitted.getpixel((0, 0)))
def img_to_file(img: Image.Image, fname: str, folder: Path,
                *, overwrite: bool = True) -> dict:
    """Save *img* as ``folder/<fname>.png`` and return a file reference dict.

    The returned ``{"path": ..., "bytes": None}`` dict is later cast to
    ``datasets.Image`` by the batch stage, so the dataset refers to files on
    disk instead of holding encoded image bytes.

    Args:
        img: image to write (saved as PNG).
        fname: file stem; ``.png`` is appended.
        folder: destination directory; created if it does not exist.
        overwrite: replace an existing file (default). This keeps the PNG in
            sync with the bbox JSON, which is always rewritten for the row.
            Pass ``False`` to keep an existing file untouched.

    Returns:
        ``{"path": str(path), "bytes": None}``.
    """
    folder.mkdir(parents=True, exist_ok=True)
    path = folder / f"{fname}.png"
    if overwrite or not path.exists():
        # Save to a sibling temp file and rename it into place. An interrupted
        # save then never leaves a truncated PNG under the final name.
        tmp = path.with_name(path.name + ".tmp")
        try:
            img.save(tmp, format="PNG")
            os.replace(tmp, path)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise
    return {"path": str(path), "bytes": None}
# ───── core functions ───────────────────────────────────────────────
# Serialises calls into the single SAM_IMAGE_MODEL shared by all worker threads.
_SAM_LOCK = threading.Lock()


def detect_and_segment(img: Image.Image, prompts: str|List[str]) -> sv.Detections:
    """Detect each prompt with Florence-2, refine the boxes with SAM-2, and merge.

    Args:
        img: image to search.
        prompts: a list of text prompts, or one comma-separated string that
            is split into prompts (blank entries are dropped).

    Returns:
        Detections from every prompt merged into one ``sv.Detections``.
        The result has no boxes when nothing was found.
    """
    if isinstance(prompts, str):
        prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    all_dets = []
    for p in prompts:
        _, res = run_florence_inference(
            model=FLORENCE_MODEL, processor=FLORENCE_PROC, device=DEVICE,
            image=img, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text=p)
        d = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, res, img.size)
        if len(d) == 0:
            # Nothing found for this prompt, so there are no boxes for SAM to segment.
            continue
        # NOTE(review): SAM_IMAGE_MODEL is shared by every worker thread. If
        # run_sam_inference follows SAM-2's image-predictor pattern
        # (set_image() then predict() on the same object), concurrent calls
        # could segment another thread's image; the lock prevents that.
        # Confirm against utils.sam.
        with _SAM_LOCK:
            all_dets.append(run_sam_inference(SAM_IMAGE_MODEL, img, d))
    return sv.Detections.merge(all_dets)
def fill_detected_bboxes(img: Image.Image, prompt: str,
                         inflate_pct: float) -> tuple[Image.Image, sv.Detections]:
    """Paint every box detected for *prompt* in solid FILL_COLOR on a copy of *img*.

    Each box grows by ``inflate_pct`` of its own width/height on every side
    and is clipped to the image bounds before being filled.

    Returns:
        ``(filled_copy, detections)``. *img* itself is not modified.
    """
    detections = detect_and_segment(img, prompt)
    canvas = img.copy()
    painter = ImageDraw.Draw(canvas)
    fill_rgb = ImageColor.getrgb(FILL_COLOR)
    width, height = img.size
    for x1, y1, x2, y2 in (row.astype(float) for row in detections.xyxy):
        pad_x = (x2 - x1) * inflate_pct
        pad_y = (y2 - y1) * inflate_pct
        left, top = max(0, x1 - pad_x), max(0, y1 - pad_y)
        right, bottom = min(width, x2 + pad_x), min(height, y2 + pad_y)
        painter.rectangle([left, top, right, bottom], fill=fill_rgb)
    return canvas, detections
# ───── threaded worker ──────────────────────────────────────────────
def process_row(idx: int, sample):
    """Detect, box-fill and annotate one dataset row, retrying on failure.

    Args:
        idx: row index; used as the zero-padded stem of every output file.
        sample: dataset row. ``sample[PROMPT_COL]`` must be a string and
            ``sample[IMAGE_COL]`` an image supporting ``.convert("RGB")``.

    Returns:
        ``("ok", green_ref, anno_ref, bbox_json)`` on success, where:
            - the two refs are the dicts returned by ``img_to_file``;
            - ``bbox_json`` is the JSON text also written to ``JSON_DIR``.
        ``("fail", reason)`` otherwise. Per-row errors are returned rather
        than raised, so one bad row cannot abort the batch.
    """
    sid = f"{idx:06d}"
    try:
        raw_prompt = sample[PROMPT_COL]
        prompt = PROMPT_MAP.get(raw_prompt, raw_prompt.split(",", 1)[0].strip())
        img_sq = make_square(sample[IMAGE_COL].convert("RGB"))
    except Exception as e:
        # Malformed input will not fix itself on retry, so report it immediately.
        return ("fail", f"preprocess: {type(e).__name__}: {e}")

    # Always make at least one attempt, so the function never returns None.
    attempts = max(1, MAX_RETRIES)
    reason = ""
    for attempt in range(1, attempts + 1):
        try:
            filled, dets = fill_detected_bboxes(
                img_sq, prompt, inflate_pct=random.uniform(*INFLATE_RANGE))
            if len(dets.xyxy) == 0:
                raise ValueError("no detections")
            bbox_json = json.dumps({"xyxy": dets.xyxy.tolist()})
            (JSON_DIR / f"{sid}_bbox.json").write_text(bbox_json)
            anno = img_sq.copy()
            for ann in (MASK_ANN, BOX_ANN, LBL_ANN):
                anno = ann.annotate(anno, dets)
            return ("ok",
                    img_to_file(filled, sid, GREEN_DIR),
                    img_to_file(anno, sid, ANNO_DIR),
                    bbox_json)
        except Exception as e:
            # Record the exception type too: e.g. str(KeyError('x')) is just "'x'".
            reason = f"{type(e).__name__}: {e}"
            if attempt < attempts:
                time.sleep(RETRY_SLEEP)
    return ("fail", reason)
# ───── run batch ────────────────────────────────────────────────────
ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N = len(ds)
print("Rows:", N)


def _load_and_process(i: int):
    """Fetch row *i* and process it, all inside the worker thread.

    Indexing ``ds`` here rather than at submit time means each row is loaded
    by the worker that needs it. Otherwise all N rows would be fetched in
    the main thread before any processing starts.
    """
    return process_row(i, ds[i])


filled_col, anno_col, json_col = [None] * N, [None] * N, [None] * N
fails = 0
# Open the failure log once per run so every failed row gets its own JSON line.
with FAILED_LOG.open("w", encoding="utf-8") as fail_log, \
        ThreadPoolExecutor(MAX_WORKERS) as pool:
    fut2idx = {pool.submit(_load_and_process, i): i for i in range(N)}
    for fut in tqdm(as_completed(fut2idx), total=N, desc="Florence+SAM"):
        idx = fut2idx[fut]
        try:
            status, *data = fut.result()
        except Exception as e:
            # An exception escaping a worker costs that row, not the whole batch.
            status, data = "fail", [f"{type(e).__name__}: {e}"]
        if status == "ok":
            filled_col[idx], anno_col[idx], json_col[idx] = data
        else:
            fails += 1
            fail_log.write(json.dumps({"idx": idx, "reason": data[0]}) + "\n")
print(f"β permanently failed rows: {fails}")

keep = [i for i, x in enumerate(filled_col) if x is not None]
new_ds = ds.select(keep)
new_ds = new_ds.add_column("bbox_filled", [filled_col[i] for i in keep])
new_ds = new_ds.add_column("annotated", [anno_col[i] for i in keep])
new_ds = new_ds.add_column("bbox_json", [json_col[i] for i in keep])
new_ds = new_ds.cast_column("bbox_filled", HFImage())
new_ds = new_ds.cast_column("annotated", HFImage())
print(f"β successes: {len(new_ds)} / {N}")
print("Columns:", new_ds.column_names)
print("QA artefacts β", QA_DIR.resolve())

# Optional push. It is skipped when no row succeeded, so a fully failed run
# does not upload an empty dataset to the hub repo.
if len(new_ds) > 0:
    new_ds.push_to_hub("fotographerai/surround_furniture_bboxfilled",
                       private=True, max_shard_size="500MB")
else:
    print("No successful rows; skipping push_to_hub.")