"""Smart image cropper: a Gradio app that crops around content found by SAM.

The Segment Anything Model (SAM) produces a binary mask of the "important"
region of an uploaded image; the crop window for the requested aspect ratio
is then positioned so that as much of that region as possible is preserved.
"""

import os
import warnings

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

warnings.filterwarnings("ignore")

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load SAM model and processor
model_id = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id).to(device)

# Dropdown label -> width/height ratio.  Single source of truth for both the
# UI choices and the label lookup, so the two can never drift apart.
ASPECT_RATIOS = {
    "16:9 (Landscape)": 16 / 9,
    "9:16 (Portrait)": 9 / 16,
    "4:3 (Standard)": 4 / 3,
    "3:4 (Portrait)": 3 / 4,
    "1:1 (Square)": 1 / 1,
    "21:9 (Ultrawide)": 21 / 9,
    "2:3 (Portrait)": 2 / 3,
    "3:2 (Landscape)": 3 / 2,
}
DEFAULT_ASPECT_RATIO_LABEL = "16:9 (Landscape)"


def get_sam_mask(image, points=None):
    """Run SAM on *image* and return one binary mask as a uint8 array.

    Args:
        image: PIL image; converted to RGB if it is in another mode.
        points: Optional list of ``[x, y]`` pixel coordinates used as the
            prompt.  When ``None`` the model is run without any prompt.

    Returns:
        A 2-D ``np.uint8`` array holding 255 where the mask is set and 0
        elsewhere.  With a point prompt, the candidate with the highest
        predicted IoU score is returned; without a prompt, the candidate
        covering the largest area is returned.  If no candidate exists, a
        fully-set mask of the image size is returned.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    processor_kwargs = {"images": image, "return_tensors": "pt"}
    if points is not None:
        processor_kwargs["input_points"] = [points]
    inputs = processor(**processor_kwargs).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Rescale the predicted masks back to the original image size, then keep
    # the candidates for the first image / first prompt.
    candidates = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0].numpy()

    if points is not None:
        # SAM proposes several candidate masks per prompt; pick the one the
        # model itself scores highest rather than blindly taking index 0.
        best_idx = int(outputs.iou_scores[0, 0].argmax())
        return candidates[best_idx].astype(np.uint8) * 255

    if candidates.shape[0] == 0:
        # If no masks found, return full image mask
        return np.ones((image.height, image.width), dtype=np.uint8) * 255

    # No prompt: keep the candidate that covers the most pixels.
    areas = candidates.reshape(candidates.shape[0], -1).sum(axis=1)
    return candidates[int(np.argmax(areas))].astype(np.uint8) * 255


def find_optimal_crop(image, mask, target_aspect_ratio):
    """Choose a crop box of the target aspect ratio that preserves the mask.

    The crop always spans the full image along one axis (whichever the
    target ratio allows) and is positioned along the other axis relative to
    the bounding box of the non-zero mask pixels.

    Args:
        image: The image being cropped.  Unused here (the geometry comes
            from *mask*); kept so existing callers keep working.
        mask: 2-D array with the same height/width as the image; values
            greater than 0 mark important content.
        target_aspect_ratio: Desired width / height of the crop.

    Returns:
        ``(left, top, right, bottom)`` as plain ints, suitable for
        ``PIL.Image.Image.crop``.
    """
    h, w = mask.shape

    # Bounding box of the important content (whole image if the mask is empty).
    ys, xs = np.where(mask > 0)
    if len(ys) == 0:
        min_x, min_y, content_width, content_height = 0, 0, w, h
    else:
        min_x, min_y = int(xs.min()), int(ys.min())
        content_width = int(xs.max()) - min_x + 1
        content_height = int(ys.max()) - min_y + 1

    # Largest window of the requested ratio that fits inside the image.
    # Clamp to >= 1 so extreme ratios never truncate to an empty crop.
    if target_aspect_ratio > w / h:
        # Target is wider than original
        target_w = w
        target_h = max(1, int(w / target_aspect_ratio))
    else:
        # Target is taller than original
        target_h = h
        target_w = max(1, int(h * target_aspect_ratio))

    def _clamp(start, size, limit):
        # Keep a window of `size` starting at `start` inside [0, limit].
        return max(0, min(start, limit - size))

    content_center_x = min_x + content_width // 2
    content_center_y = min_y + content_height // 2
    fits_w = content_width <= target_w
    fits_h = content_height <= target_h

    # Per axis: centre the window on the content, except when this axis can
    # hold the content but the other cannot -- then anchor the window at the
    # content's leading edge instead.
    if fits_w and not fits_h:
        x = _clamp(min_x, target_w, w)
    else:
        x = _clamp(content_center_x - target_w // 2, target_w, w)

    if fits_h and not fits_w:
        y = _clamp(min_y, target_h, h)
    else:
        y = _clamp(content_center_y - target_h // 2, target_h, h)

    return (int(x), int(y), int(x + target_w), int(y + target_h))


def smart_crop(input_image, target_aspect_ratio, point_x=None, point_y=None):
    """Crop *input_image* to the target ratio around SAM-detected content.

    Args:
        input_image: PIL image or NumPy array; ``None`` is tolerated.
        target_aspect_ratio: Desired width / height of the result.
        point_x: Optional x pixel coordinate of a point of interest.
        point_y: Optional y pixel coordinate of a point of interest.

    Returns:
        ``(cropped_image, visualization_path)``: the cropped PIL image and
        the path of a PNG showing the original, the mask and the crop side
        by side.  ``(None, None)`` when *input_image* is ``None``, so callers
        can always unpack two values.
    """
    if input_image is None:
        return None, None

    # Open image and convert to RGB
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image)
    else:
        pil_image = input_image
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Use the click as a prompt only when both coordinates are present.
    # Zero is a valid pixel coordinate (first row / first column).
    points = None
    if (
        point_x is not None
        and point_y is not None
        and point_x >= 0
        and point_y >= 0
    ):
        points = [[point_x, point_y]]

    mask = get_sam_mask(pil_image, points)

    # Calculate the best crop and apply it
    crop_box = find_optimal_crop(pil_image, mask, target_aspect_ratio)
    cropped_img = pil_image.crop(crop_box)

    # Visualize the process
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (pil_image, "Original Image", {}),
        (mask, "SAM Segmentation Mask", {"cmap": "gray"}),
        (cropped_img, f"Smart Cropped ({target_aspect_ratio:.2f})", {}),
    )
    for axis, (content, title, imshow_kwargs) in zip(ax, panels):
        axis.imshow(content, **imshow_kwargs)
        axis.set_title(title)
        axis.axis("off")
    fig.tight_layout()

    # NOTE: fixed file name in the working directory; every call overwrites
    # the previous visualization.
    vis_path = "visualization.png"
    # Save/close through the figure handle rather than pyplot's implicit
    # "current figure", so this exact figure is the one written and released.
    fig.savefig(vis_path)
    plt.close(fig)

    return cropped_img, vis_path


def aspect_ratio_options(choice):
    """Map aspect ratio choices to actual values (16:9 if unknown)."""
    return ASPECT_RATIOS.get(choice, 16 / 9)


def process_image(input_image, aspect_ratio_choice, point_x=None, point_y=None):
    """Gradio handler: resolve the ratio label and run the smart crop.

    Args:
        input_image: Uploaded image, or ``None`` if nothing was uploaded.
        aspect_ratio_choice: Label selected in the aspect-ratio dropdown.
        point_x: Optional x pixel coordinate clicked by the user.
        point_y: Optional y pixel coordinate clicked by the user.

    Returns:
        ``(cropped_image, visualization_path)``, or ``(None, None)`` when no
        image was supplied.
    """
    if input_image is None:
        return None, None

    # Get the actual aspect ratio value
    target_aspect_ratio = aspect_ratio_options(aspect_ratio_choice)

    # Process the image
    return smart_crop(input_image, target_aspect_ratio, point_x, point_y)


def create_app():
    """Build and return the Gradio Blocks UI for the smart cropper."""
    with gr.Blocks(title="Smart Image Cropper using SAM") as app:
        gr.Markdown("# Smart Image Cropper using Segment Anything Model (SAM)")
        gr.Markdown("""
        Upload an image and choose your target aspect ratio. The app will use the Segment Anything Model (SAM)
        to identify important content and crop intelligently to preserve it.

        Optionally, you can click on the uploaded image to specify a point of interest.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                aspect_ratio = gr.Dropdown(
                    choices=list(ASPECT_RATIOS),
                    value=DEFAULT_ASPECT_RATIO_LABEL,
                    label="Target Aspect Ratio",
                )
                # Last clicked (x, y) on the input image; [None, None] = no click.
                point_coords = gr.State(value=[None, None])

                def update_coords(img, evt: gr.SelectData):
                    return [evt.index[0], evt.index[1]]

                def reset_coords():
                    return [None, None]

                input_image.select(
                    update_coords,
                    inputs=[input_image],
                    outputs=[point_coords],
                )
                # A click belongs to the image it was made on: forget it when
                # the image is replaced or cleared, otherwise a stale point
                # would be used as the prompt for the next image.
                input_image.change(reset_coords, outputs=[point_coords])

                process_btn = gr.Button("Process Image")

            with gr.Column(scale=2):
                output_image = gr.Image(type="pil", label="Cropped Result")
                visualization = gr.Image(
                    type="filepath",
                    label="Process Visualization",
                )

        def run_crop(img, ratio, coords):
            return process_image(img, ratio, coords[0], coords[1])

        process_btn.click(
            fn=run_crop,
            inputs=[input_image, aspect_ratio, point_coords],
            outputs=[output_image, visualization],
        )

        gr.Markdown("""
        ## How It Works

        1. The Segment Anything Model (SAM) analyzes your image to identify the important content
        2. The app finds the optimal crop window that maximizes the preservation of that content
        3. The image is cropped to your desired aspect ratio while keeping the important parts

        ## Tips

        - For better results with specific subjects, click on the important object in the image
        - Try different aspect ratios to see how the model adapts the cropping
        """)

    return app


# Create and launch the app
demo = create_app()

# The app is launched both when run as a script (local testing) and when the
# module is imported (e.g. by a hosting platform such as Hugging Face Spaces).
demo.launch()