Spaces:
Running
on
Zero
Running
on
Zero
| #!/usr/bin/env python3 | |
| """ | |
| Standalone script: Given two images, generate a final difference mask using the | |
| same pipeline as visualize_mask_diff (without any visualization output). | |
| Pipeline: | |
| 1) Align images to a preferred resolution/crop so they share the same size. | |
| 2) Pixel-diff screening across parameter combinations; skip if any hull ratio is | |
| outside [hull_min_allowed, hull_max_allowed]. | |
| 3) Color-diff to produce the final mask; remove small areas and re-check hull | |
| ratio. Save final mask to output path. | |
| """ | |
| import os | |
| import json | |
| import argparse | |
| from typing import Tuple, Optional | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
# Candidate output sizes as (width, height) pairs, listed from the tallest
# portrait shape to the widest landscape shape.
PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]


def choose_preferred_resolution(image_width: int, image_height: int,
                                resolutions: Optional[list] = None) -> Tuple[int, int]:
    """Return the candidate resolution whose aspect ratio is closest to the image's.

    Args:
        image_width: Width of the input image in pixels.
        image_height: Height of the input image in pixels. Values below 1 are
            treated as 1 so the aspect ratio is always defined.
        resolutions: Optional list of (width, height) pairs to choose from.
            Defaults to PREFERRED_KONTEXT_RESOLUTIONS when omitted.

    Returns:
        The (width, height) pair minimising |image_aspect - width / height|.
        On ties the earliest candidate in the list wins.

    Raises:
        ValueError: If an explicitly supplied ``resolutions`` list is empty.
    """
    candidates = PREFERRED_KONTEXT_RESOLUTIONS if resolutions is None else resolutions
    if not candidates:
        raise ValueError("resolutions must contain at least one (width, height) pair")
    aspect_ratio = image_width / max(1, image_height)
    # min() returns the first minimal element, so list order breaks ties.
    w_best, h_best = min(candidates, key=lambda wh: abs(aspect_ratio - (wh[0] / wh[1])))
    return int(w_best), int(h_best)
def align_images(source_path: str, target_path: str) -> Tuple[Image.Image, Image.Image]:
    """Load two images and bring them to a common size.

    The source image is resized (LANCZOS) to the preferred resolution chosen
    from its own aspect ratio. Both images are then cropped from the top-left
    corner to the largest width/height they share. The target is never
    resized, only cropped.

    Args:
        source_path: Path to the source image file.
        target_path: Path to the target image file.

    Returns:
        A ``(source_aligned, target_aligned)`` pair of RGB PIL images with
        identical dimensions.
    """
    # convert() yields an independent in-memory image, so the underlying file
    # handles can be released as soon as each `with` block exits.
    with Image.open(source_path) as source_file:
        source_img = source_file.convert("RGB")
    with Image.open(target_path) as target_file:
        target_img = target_file.convert("RGB")

    pref_w, pref_h = choose_preferred_resolution(source_img.width, source_img.height)
    source_resized = source_img.resize((pref_w, pref_h), Image.Resampling.LANCZOS)

    # Shared top-left window: the smaller extent along each axis.
    crop_w = min(source_resized.width, target_img.width)
    crop_h = min(source_resized.height, target_img.height)
    crop_box = (0, 0, crop_w, crop_h)
    return source_resized.crop(crop_box), target_img.crop(crop_box)
def pil_to_cv_gray(img: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV grayscale array.

    The pixel array is treated as RGB, reordered to BGR, and then reduced to
    a single gray channel.
    """
    rgb_pixels = np.array(img)
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2GRAY)
def generate_pixel_diff_mask(img1: Image.Image, img2: Image.Image,
                             threshold: Optional[int] = None,
                             clean_kernel_size: Optional[int] = 11) -> np.ndarray:
    """Build a binary mask from the absolute grayscale difference of two images.

    Args:
        img1: First image.
        img2: Second image; must match ``img1`` in size for the difference.
        threshold: Fixed binarisation level. When ``None`` the level is
            selected automatically with Otsu's method.
        clean_kernel_size: Diameter of the elliptical structuring element used
            for an opening followed by a closing. ``None`` or a non-positive
            value skips the clean-up.

    Returns:
        Mask array with values 0 or 255.
    """
    abs_delta = cv2.absdiff(pil_to_cv_gray(img1), pil_to_cv_gray(img2))

    if threshold is not None:
        level, mode = int(threshold), cv2.THRESH_BINARY
    else:
        # With the OTSU flag the level argument is ignored by OpenCV.
        level, mode = 0, cv2.THRESH_BINARY | cv2.THRESH_OTSU
    _, mask = cv2.threshold(abs_delta, level, 255, mode)

    if not clean_kernel_size or clean_kernel_size <= 0:
        return mask

    disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clean_kernel_size, clean_kernel_size))
    # Opening drops small specks; closing then fills small holes.
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        mask = cv2.morphologyEx(mask, operation, disk)
    return mask
def generate_color_diff_mask(img1: Image.Image, img2: Image.Image,
                             threshold: Optional[int] = None,
                             clean_kernel_size: Optional[int] = 21) -> np.ndarray:
    """Build a binary mask from the per-pixel LAB colour distance of two images.

    Both images are converted RGB -> BGR -> LAB, and the Euclidean distance
    between their LAB values is computed per pixel. The distance map is
    min-max normalised to 0..255 before thresholding, so ``threshold`` applies
    to the normalised values.

    Args:
        img1: First image.
        img2: Second image; must match ``img1`` in size.
        threshold: Fixed level on the normalised distance. When ``None`` the
            level is selected automatically with Otsu's method.
        clean_kernel_size: Diameter of the elliptical structuring element used
            for an opening followed by a closing. ``None`` or a non-positive
            value skips the clean-up.

    Returns:
        Mask array with values 0 or 255.
    """
    def to_lab(pil_img):
        # Float copy so the subtraction below cannot wrap around.
        bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype("float32")

    delta = to_lab(img1) - to_lab(img2)
    distance = np.sqrt(np.sum(delta * delta, axis=2))
    distance_u8 = cv2.normalize(distance, None, 0, 255, cv2.NORM_MINMAX).astype("uint8")

    if threshold is not None:
        level, mode = int(threshold), cv2.THRESH_BINARY
    else:
        level, mode = 0, cv2.THRESH_BINARY | cv2.THRESH_OTSU
    _, mask = cv2.threshold(distance_u8, level, 255, mode)

    if not clean_kernel_size or clean_kernel_size <= 0:
        return mask

    disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clean_kernel_size, clean_kernel_size))
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        mask = cv2.morphologyEx(mask, operation, disk)
    return mask
def compute_unified_contour(mask_bin: np.ndarray, contours: list, min_area: int = 40,
                            method: str = "morph", morph_kernel: int = 15,
                            morph_iters: int = 1, approx_epsilon_ratio: float = 0.01):
    """Merge the sufficiently large contours into one simplified outline.

    Args:
        mask_bin: Binary mask; used only as the shape/dtype template for the
            scratch canvas of the morphological method.
        contours: Contours to merge.
        min_area: Contours with area below ``max(1, min_area)`` are ignored.
        method: ``"convex_hull"`` returns the convex hull of all kept contour
            points. Any other value draws the kept contours filled, applies
            morphological closing, and takes the largest external contour of
            the result.
        morph_kernel: Diameter of the elliptical closing kernel.
        morph_iters: Number of closing passes (at least one is always run).
        approx_epsilon_ratio: Polygon simplification tolerance as a fraction
            of the outline's perimeter.

    Returns:
        The simplified outline from ``cv2.approxPolyDP``, or ``None`` when no
        contour passes the area filter or the closed union has no contour.
    """
    area_floor = max(1, min_area)
    kept = [cnt for cnt in contours if cv2.contourArea(cnt) >= area_floor]
    if not kept:
        return None

    def simplify(closed_curve):
        tolerance = approx_epsilon_ratio * cv2.arcLength(closed_curve, True)
        return cv2.approxPolyDP(closed_curve, tolerance, True)

    if method == "convex_hull":
        return simplify(cv2.convexHull(np.vstack(kept)))

    canvas = np.zeros_like(mask_bin)
    cv2.drawContours(canvas, kept, -1, 255, thickness=-1)
    disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (morph_kernel, morph_kernel))
    closed = canvas.copy()
    for _ in range(max(1, morph_iters)):
        closed = cv2.morphologyEx(closed, cv2.MORPH_CLOSE, disk)

    # findContours returns 2 or 3 values depending on the OpenCV version.
    found = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlines = found[0] if len(found) == 2 else found[1]
    if not outlines:
        return None
    return simplify(max(outlines, key=cv2.contourArea))
def compute_hull_area_ratio(mask: np.ndarray, min_area: int = 40) -> float:
    """Return the convex-hull area of the mask's regions as a fraction of the image.

    Non-zero pixels are treated as foreground. External contours smaller than
    ``min_area`` are ignored when the hull is built.

    Returns:
        ``hull_area / image_area``, or ``0.0`` when there is no contour, no
        contour passes the area filter, or the hull has fewer than 3 points.
    """
    binary = (mask > 0).astype("uint8") * 255

    # findContours returns 2 or 3 values depending on the OpenCV version.
    found = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlines = found[0] if len(found) == 2 else found[1]
    if not outlines:
        return 0.0

    hull = compute_unified_contour(binary, outlines, min_area=min_area,
                                   method="convex_hull", morph_kernel=15, morph_iters=1)
    if hull is None or len(hull) < 3:
        return 0.0

    rows, cols = binary.shape[0], binary.shape[1]
    return float(cv2.contourArea(hull)) / max(1.0, float(rows * cols))
def clean_and_fill_mask(mask: np.ndarray, min_area: int = 40) -> np.ndarray:
    """Drop small regions from a mask and fill the interiors of the rest.

    Non-zero pixels are treated as foreground. Each external contour with
    area of at least ``max(1, min_area)`` is redrawn filled on a blank canvas,
    which also closes any holes inside it.

    Returns:
        A new 0/255 mask with the same shape as the input.
    """
    binary = (mask > 0).astype("uint8") * 255
    found = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlines = found[0] if len(found) == 2 else found[1]

    area_floor = max(1, min_area)
    canvas = np.zeros_like(binary)
    for outline in outlines:
        if cv2.contourArea(outline) < area_floor:
            continue
        cv2.drawContours(canvas, [outline], 0, 255, -1)
    return canvas
def generate_final_difference_mask(source_path: str,
                                   target_path: str,
                                   hull_min_allowed: float = 0.001,
                                   hull_max_allowed: float = 0.75,
                                   pixel_parameters: Optional[list] = None,
                                   pixel_clean_kernel_default: int = 11,
                                   color_clean_kernel: int = 3,
                                   roll_radius: int = 0,
                                   roll_iters: int = 1) -> Optional[np.ndarray]:
    """Run the full difference-mask pipeline for one source/target image pair.

    Steps:
        1. Align the two images to a common size.
        2. Screen with grayscale pixel-diff masks, one per
           ``(threshold, kernel_size)`` pair; reject the pair as soon as any
           hull-area ratio lies outside the allowed range.
        3. Build the colour-diff mask, remove small regions, and fill the
           convex hull of what remains.
        4. Optionally smooth the hull mask with closing then opening.
        5. Re-check the hull-area ratio of the final mask.

    Args:
        source_path: Path to the source image.
        target_path: Path to the target image.
        hull_min_allowed: Smallest accepted hull-area ratio.
        hull_max_allowed: Largest accepted hull-area ratio.
        pixel_parameters: ``(threshold, clean_kernel_size)`` pairs for the
            screening step; defaults to ``[(None, 5), (None, 11), (50, 5)]``.
        pixel_clean_kernel_default: Accepted but not used by this function.
        color_clean_kernel: Clean-up kernel size for the colour-diff mask.
        roll_radius: Radius of the smoothing disk; 0 disables smoothing.
        roll_iters: Number of close/open smoothing passes.

    Returns:
        The final 0/255 mask, or ``None`` when any check rejects the pair.
    """
    def out_of_bounds(ratio):
        return ratio < hull_min_allowed or ratio > hull_max_allowed

    if pixel_parameters is None:
        # Mirrors the tuned combinations used in visualization script
        pixel_parameters = [(None, 5), (None, 11), (50, 5)]

    src_img, tgt_img = align_images(source_path, target_path)

    # Pixel screening: the first out-of-range combination rejects the pair.
    for thr, ksize in pixel_parameters:
        screening_mask = generate_pixel_diff_mask(src_img, tgt_img, threshold=thr,
                                                  clean_kernel_size=ksize)
        if out_of_bounds(compute_hull_area_ratio(screening_mask, min_area=40)):
            return None

    # Colour-based mask with small regions removed.
    color_mask = generate_color_diff_mask(src_img, tgt_img, threshold=None,
                                          clean_kernel_size=color_clean_kernel)
    cleaned = clean_and_fill_mask(color_mask, min_area=40)

    # Fill the convex hull of everything that survived the clean-up.
    binary = (cleaned > 0).astype("uint8") * 255
    found = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    outlines = found[0] if len(found) == 2 else found[1]
    hull = compute_unified_contour(binary, outlines, min_area=40, method="convex_hull",
                                   morph_kernel=15, morph_iters=1)
    if hull is None or len(hull) < 3:
        return None
    hull_mask = np.zeros_like(binary)
    cv2.drawContours(hull_mask, [hull], -1, 255, thickness=-1)

    # Rolling-circle smoothing: closing then opening with a disk of radius R.
    if roll_radius and roll_radius > 0 and roll_iters and roll_iters > 0:
        diameter = max(1, 2 * int(roll_radius) + 1)
        disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
        for _ in range(max(1, roll_iters)):
            for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
                hull_mask = cv2.morphologyEx(hull_mask, operation, disk)

    # Final hull ratio check on the hull-filled binary mask.
    if out_of_bounds(compute_hull_area_ratio(hull_mask, min_area=40)):
        return None
    return hull_mask
def main():
    """Command-line entry point: build a mask for one image pair or a whole dataset.

    Single-pair mode runs when ``--source``, ``--target`` and ``--output`` are
    all given. Supplying only some of them is rejected, so a forgotten flag
    cannot silently start a dataset run on the default paths.

    Dataset mode reads a JSON list of entries with ``source_image`` and
    ``target_image`` fields (relative to ``--dataset_dir``) and an optional
    ``id``. It writes one PNG per successful entry into
    ``--dataset_output_dir`` and prints a summary of the counters.
    """
    parser = argparse.ArgumentParser(description="Generate final difference mask (single pair or whole dataset)")
    # Single-pair mode (optional): if provided, runs single pair; otherwise runs dataset mode
    parser.add_argument("--source", help="Path to source image")
    parser.add_argument("--target", help="Path to target image")
    parser.add_argument("--output", help="Path to write the final mask (PNG)")
    # Dataset mode (defaults to user's dataset paths)
    parser.add_argument("--dataset_dir", default="/home/lzc/KontextFill/InstructV2V/extracted_dataset", help="Base dataset dir with source_images/ and target_images/")
    parser.add_argument("--dataset_output_dir", default="/home/lzc/KontextFill/visualizations_masks/inference_masks_smoothing", help="Output directory for batch masks")
    parser.add_argument("--json_path", default="/home/lzc/KontextFill/InstructV2V/extracted_dataset/extracted_data.json", help="Dataset JSON mapping with fields 'source_image' and 'target_image'")
    # Common params
    parser.add_argument("--hull_min_allowed", type=float, default=0.001)
    parser.add_argument("--hull_max_allowed", type=float, default=0.75)
    parser.add_argument("--color_clean_kernel", type=int, default=3)
    parser.add_argument("--roll_radius", type=int, default=15, help="Rolling-circle smoothing radius (pixels); 0 disables")
    parser.add_argument("--roll_iters", type=int, default=5, help="Rolling smoothing iterations")
    args = parser.parse_args()

    single_pair_args = (args.source, args.target, args.output)
    if any(single_pair_args) and not all(single_pair_args):
        # parser.error prints usage and exits with status 2.
        parser.error("--source, --target and --output must be provided together for single-pair mode")

    # Settings shared by both modes.
    mask_options = dict(
        hull_min_allowed=args.hull_min_allowed,
        hull_max_allowed=args.hull_max_allowed,
        pixel_parameters=[(None, 5), (None, 11), (50, 5)],
        color_clean_kernel=args.color_clean_kernel,
        roll_radius=args.roll_radius,
        roll_iters=args.roll_iters,
    )

    # Decide mode: single or dataset
    if all(single_pair_args):
        mask = generate_final_difference_mask(source_path=args.source, target_path=args.target, **mask_options)
        if mask is None:
            print("Single-pair inference failed; no output saved.")
            return
        os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(args.output, mask):
            print(f"Failed to write mask to {args.output}")
        return

    # Dataset mode using JSON mapping
    out_dir = args.dataset_output_dir
    os.makedirs(out_dir, exist_ok=True)
    processed = 0
    skipped = 0
    failed = 0
    missing_files = 0
    try:
        with open(args.json_path, "r", encoding="utf-8") as f:
            entries = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to read JSON mapping at {args.json_path}: {e}")
        entries = []

    for item in entries:
        try:
            src_rel = item.get("source_image")
            tgt_rel = item.get("target_image")
            edit_id = item.get("id")
            if not src_rel or not tgt_rel:
                skipped += 1
                continue
            s = os.path.join(args.dataset_dir, src_rel)
            t = os.path.join(args.dataset_dir, tgt_rel)
            if not (os.path.exists(s) and os.path.exists(t)):
                missing_files += 1
                continue
            mask = generate_final_difference_mask(source_path=s, target_path=t, **mask_options)
            if mask is None:
                failed += 1
                continue
            # Numeric ids give "edit_NNNN"; otherwise fall back to the source file stem.
            if isinstance(edit_id, int) or (isinstance(edit_id, str) and edit_id.isdigit()):
                name = f"edit_{int(edit_id):04d}"
            else:
                name = os.path.splitext(os.path.basename(src_rel))[0]
            out_path = os.path.join(out_dir, f"{name}.png")
            if not cv2.imwrite(out_path, mask):
                print(f"Failed to write mask to {out_path}")
                failed += 1
                continue
            processed += 1
        except Exception as e:
            # Best-effort batch: keep going, but report what went wrong.
            print(f"Skipping entry {item!r}: {e!r}")
            skipped += 1
    print(f"Batch done. Processed={processed}, Failed={failed}, Skipped={skipped}, MissingFiles={missing_files}, OutputDir={out_dir}")


if __name__ == "__main__":
    main()