import sys
import time

import numpy as np
from PIL import Image
import torch


def log(msg: str):
    print(msg, flush=True)


def _make_log_tqdm():
    """tqdm subclass that routes per-file download progress to our log buffer."""
    try:
        from tqdm.auto import tqdm as _Base
    except ImportError:
        from tqdm import tqdm as _Base

    class _LogTqdm(_Base):
        def __init__(self, *args, **kwargs):
            self.__last_pct = -1
            # disable=True suppresses terminal rendering; tracking still works
            super().__init__(*args, disable=True, **kwargs)
            if self.total and self.total > 100_000:
                log(f"[SAM2] Downloading {self.desc or 'file'} ({self.total/1e6:.1f} MB) ...")

        def update(self, n=1):
            super().update(n)
            if not self.total or self.total <= 100_000:
                return
            pct = min(100, int(self.n / self.total * 100))
            if pct >= self.__last_pct + 10:
                log(f"[SAM2] {self.desc}: {self.n/1e6:.0f}/{self.total/1e6:.0f} MB ({pct}%)")
                self.__last_pct = pct

        def close(self):
            super().close()
            if self.total and self.total > 100_000 and self.n >= self.total * 0.99:
                log(f"[SAM2] {self.desc}: done")

    return _LogTqdm

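
# Usage sketch (hedged): _make_log_tqdm returns a *class* because
# huggingface_hub instantiates one progress bar per downloaded file. Driving it
# by hand, with made-up sizes, would log roughly every 10%:
#
#     Bar = _make_log_tqdm()
#     bar = Bar(total=5_000_000, desc="model.safetensors")  # "Downloading ..."
#     for _ in range(10):
#         bar.update(500_000)
#     bar.close()                                           # "... done"
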

def load_sam2():
    from huggingface_hub import snapshot_download
    from transformers import Sam2Model, Sam2Processor

    model_id = "facebook/sam2-hiera-large"

    # Phase 1: download (instant if already cached; shows per-file progress if not)
    log("[SAM2] Checking model files in HF cache ...")
    t0 = time.time()
    snapshot_download(model_id, tqdm_class=_make_log_tqdm())
    log(f"[SAM2] Cache ready ({time.time()-t0:.1f}s). Loading processor ...")

    # Phase 2: deserialize processor
    t1 = time.time()
    processor = Sam2Processor.from_pretrained(model_id)
    log(f"[SAM2] Processor loaded ({time.time()-t1:.1f}s). Loading model weights ...")

    # Phase 3: deserialize model (~1-2 GB of weights, loaded onto CPU here and
    # moved to the GPU in segment_regions; can take 30-60s)
    t2 = time.time()
    model = Sam2Model.from_pretrained(model_id)
    model.eval()
    log(f"[SAM2] Model loaded ({time.time()-t2:.1f}s). Total init: {time.time()-t0:.1f}s.")
    return model, processor

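
# Note (assumption): from_pretrained loads full-precision weights by default.
# On a memory-constrained GPU, half precision roughly halves the footprint, at
# some (untested here) cost to mask quality:
#
#     model = Sam2Model.from_pretrained(model_id, torch_dtype=torch.float16)
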

_sam2_cache = None


def get_sam2():
    global _sam2_cache
    if _sam2_cache is None:
        log("[SAM2] Cold start: initializing model for the first time ...")
        _sam2_cache = load_sam2()
    else:
        log("[SAM2] Using cached model.")
    return _sam2_cache

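
# Note (assumption): this lazy singleton assumes requests are handled one at a
# time. If the serving framework runs handlers on concurrent threads, two cold
# starts could both see None and load the model twice. A lock-guarded variant:
#
#     import threading
#     _sam2_lock = threading.Lock()
#
#     def get_sam2():
#         global _sam2_cache
#         with _sam2_lock:  # serialize the None-check and the load
#             if _sam2_cache is None:
#                 _sam2_cache = load_sam2()
#         return _sam2_cache
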

# Each prompt: (click_x, click_y, bbox_x1, bbox_y1, bbox_x2, bbox_y2)
# All values normalized to [0, 1]. The bbox constrains SAM2 to look only within
# that region, which is far more reliable than a point alone for body parts.
DEFAULT_PROMPTS = {
    "breast_left": (0.40, 0.36, 0.28, 0.26, 0.50, 0.46),
    "breast_right": (0.60, 0.36, 0.50, 0.26, 0.72, 0.46),
    "buttocks": (0.50, 0.72, 0.30, 0.62, 0.70, 0.85),
    "ponytail": (0.50, 0.10, 0.35, 0.00, 0.65, 0.20),
    "hair": (0.50, 0.10, 0.30, 0.00, 0.70, 0.25),
}

ANATOMY_REGIONS = {"breast_left", "breast_right", "buttocks"}

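
# Illustration (hedged): how a normalized prompt maps to pixels; the same
# arithmetic appears inline in segment_regions below, and the image size here
# is made up.
#
#     def denorm(prompt, size):
#         W, H = size
#         cx, cy, x1, y1, x2, y2 = prompt
#         return (cx * W, cy * H), (x1 * W, y1 * H, x2 * W, y2 * H)
#
#     click, box = denorm(DEFAULT_PROMPTS["hair"], (1000, 1500))
#     # click == (500.0, 150.0); box == (300.0, 0.0, 700.0, 375.0)
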

def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
    log(f"[Segment] Requested: {requested} | image size: {image.size}")

    # Body-region masks come from MediaPipe pose + ellipse, not SAM2.
    # SAM2 segments by pixel similarity, which on clothed photos catches the
    # tank top / shirt color rather than the underlying anatomy.
    anatomy_requests = [r for r in requested if r in ANATOMY_REGIONS]
    sam_requests = [r for r in requested if r not in ANATOMY_REGIONS]
    results: dict = {}

    if anatomy_requests:
        from anatomy import segment_anatomy
        results.update(segment_anatomy(image, anatomy_requests))

    if not sam_requests:
        log(f"[Segment] All {len(results)} regions complete (anatomy only).")
        return results

    log(f"[SAM2] Falling back to SAM2 for: {sam_requests}")
    model, processor = get_sam2()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log(f"[SAM2] Using device: {device}")
    model = model.to(device)
    W, H = image.size

    for i, region in enumerate(sam_requests):
        if region not in DEFAULT_PROMPTS:
            log(f"[SAM2] Skipping unknown region: {region}")
            continue
        log(f"[SAM2] Processing region {i+1}/{len(sam_requests)}: {region} ...")
        t = time.time()
        prompt = DEFAULT_PROMPTS[region]
        if click_points and region in click_points:
            px, py = click_points[region]
        else:
            px, py = prompt[0] * W, prompt[1] * H

        # Pass both a click point AND a bounding box. The bbox constrains SAM2
        # to segment only inside that region, which is essential for parts of
        # a body where a click alone yields ambiguous results (subpart vs torso
        # vs whole subject).
        bx1, by1, bx2, by2 = prompt[2] * W, prompt[3] * H, prompt[4] * W, prompt[5] * H
        log(f"[SAM2] {region} click=({px:.0f},{py:.0f}) bbox=({bx1:.0f},{by1:.0f},{bx2:.0f},{by2:.0f})")

        # 4-level nesting: [image][object][point][xy]; boxes: [image][object][xyxy]
        inputs = processor(
            images=image,
            input_points=[[[[px, py]]]],
            input_boxes=[[[bx1, by1, bx2, by2]]],
            return_tensors="pt",
        ).to(device)
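        # Illustration (hedged, extending the nesting comment above): two
        # foreground clicks on the same object would follow the same scheme,
        #     input_points=[[[[x1, y1], [x2, y2]]]]
        # paired with input_labels=[[[1, 1]]] to mark both as foreground,
        # assuming Sam2Processor accepts labels the way the SAM processor does.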
        with torch.no_grad():
            outputs = model(**inputs)
        masks = processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
        )[0]

        # SAM2 returns 3 candidate masks (subpart / part / whole). Simply
        # argmax-ing the IoU scores often picks the "whole subject" mask, which
        # is wrong for body-region segmentation; we want the local part the
        # click landed on, so candidates are filtered by area first (below).
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        mtensor = masks[0].numpy()
        if mtensor.ndim == 4:
            mtensor = mtensor[0]
        # mtensor is now (num_masks, H, W)
        total_px = mtensor.shape[1] * mtensor.shape[2]
        areas = [int(np.sum(m > 0)) for m in mtensor]
        log(f"[SAM2] mask shape: {mtensor.shape}, areas: {areas}, scores: {scores.tolist()}")
        # Filter to masks with area between 0.5% and 40% of the image, then pick
        # the one with the *highest* IoU score (the model's own confidence), not
        # the smallest, which often gave us a low-confidence sliver.
        candidates = [
            j for j in range(len(mtensor))
            if 0.005 * total_px <= areas[j] <= 0.40 * total_px
        ]
        if candidates:
            best = max(candidates, key=lambda j: scores[j])
            log(f"[SAM2] picked mask idx={best} (highest score within 0.5-40% range, score={scores[best]:.3f}, area={areas[best]})")
        else:
            best = int(np.argmax(scores))
            log(f"[SAM2] no mask in range; falling back to argmax idx={best}")

        mask = mtensor[best].astype(bool)
        if not mask.any():
            # Guard: an all-empty mask would crash the bbox indexing below.
            log(f"[SAM2] '{region}' produced an empty mask; skipping.")
            continue
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
        log(f"[SAM2] '{region}' done in {time.time()-t:.1f}s; bbox=[{cmin},{rmin},{cmax-cmin},{rmax-rmin}] score={scores[best]:.3f}")
        results[region] = {
            "mask": mask.tolist(),
            "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
        }
    log(f"[Segment] All {len(results)} regions complete.")
    return results
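

# Usage sketch (hedged: "input.jpg" and the region list are illustrative, not
# part of the app's entry point):
if __name__ == "__main__":
    img = Image.open("input.jpg").convert("RGB")
    out = segment_regions(img, ["hair", "ponytail"])
    for name, r in out.items():
        x, y, w, h = r["bbox"]
        print(f"{name}: bbox x={x} y={y} w={w} h={h}")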