import os
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
import warnings

warnings.filterwarnings("ignore")

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the SAM model and processor
model_id = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id).to(device)
def get_sam_mask(image, points=None):
    """
    Generate a segmentation mask with SAM, either for the whole image or
    from an optional point prompt.
    """
    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")

    if points is None:
        # No point prompt: run SAM on the whole image
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        # Post-process the predicted masks back to the original image size
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0]

        # Convert to binary masks and keep the largest one
        masks = masks.numpy()
        if masks.shape[0] > 0:
            # Calculate the area of each candidate mask and pick the largest
            areas = [np.sum(mask) for mask in masks]
            largest_mask_idx = np.argmax(areas)
            return masks[largest_mask_idx].astype(np.uint8) * 255
        else:
            # If no masks were found, fall back to a full-image mask
            return np.ones((image.height, image.width), dtype=np.uint8) * 255
    else:
        # Use the provided point prompt to generate a mask
        inputs = processor(images=image, input_points=[points], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0]
        return masks[0].numpy().astype(np.uint8) * 255
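
# A minimal alternative sketch (not wired into the app above): SAM also returns
# predicted IoU scores in `outputs.iou_scores`, roughly one score per candidate
# mask, so the "best" mask could be chosen by score rather than by area. The
# helper name and the assumption that `masks` is the post-processed
# (num_masks, H, W) tensor from above are illustrative only.
def pick_mask_by_iou(masks, iou_scores):
    """Return the candidate mask with the highest predicted IoU score."""
    best_idx = iou_scores.squeeze().argmax().item()
    return masks[best_idx].numpy().astype(np.uint8) * 255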
def find_optimal_crop(image, mask, target_aspect_ratio):
    """
    Find the optimal crop that preserves important content based on the mask
    """
    h, w = mask.shape

    # Find the bounding box of the important content, i.e. where the mask is non-zero
    y_indices, x_indices = np.where(mask > 0)
    if len(y_indices) == 0 or len(x_indices) == 0:
        # Fallback if no mask is found
        content_box = (0, 0, w, h)
    else:
        # Get the bounding box of the important content
        min_x, max_x = np.min(x_indices), np.max(x_indices)
        min_y, max_y = np.min(y_indices), np.max(y_indices)
        content_width = max_x - min_x + 1
        content_height = max_y - min_y + 1
        content_box = (min_x, min_y, content_width, content_height)

    # Calculate target dimensions based on the original image
    if target_aspect_ratio > w / h:
        # Target is wider than the original: keep the full width
        target_w = w
        target_h = int(w / target_aspect_ratio)
    else:
        # Target is narrower (taller) than the original: keep the full height
        target_h = h
        target_w = int(h * target_aspect_ratio)

    # Calculate the center of the important content
    content_center_x = content_box[0] + content_box[2] // 2
    content_center_y = content_box[1] + content_box[3] // 2

    # Try to center the crop on the important content
    x = max(0, min(content_center_x - target_w // 2, w - target_w))
    y = max(0, min(content_center_y - target_h // 2, h - target_h))

    # Check whether the important content fits within this crop
    min_x, min_y, content_width, content_height = content_box
    max_x = min_x + content_width
    max_y = min_y + content_height

    if target_w >= content_width and target_h >= content_height:
        # The crop is large enough to include all content, so center it
        x = max(0, min(content_center_x - target_w // 2, w - target_w))
        y = max(0, min(content_center_y - target_h // 2, h - target_h))
    else:
        # The crop cannot contain all content: maximize the visible content,
        # starting the crop at the content's top-left corner
        x = max(0, min(min_x, w - target_w))
        y = max(0, min(min_y, h - target_h))
        # If the width still doesn't fit, center the crop horizontally on the content
        if content_width > target_w:
            x = max(0, min(content_center_x - target_w // 2, w - target_w))
        # If the height still doesn't fit, center the crop vertically on the content
        if content_height > target_h:
            y = max(0, min(content_center_y - target_h // 2, h - target_h))

    return (x, y, x + target_w, y + target_h)
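
# Worked example (illustrative numbers, not taken from the app): for a 1200x800
# image and a 1:1 target, 1.0 < 1200/800, so the else branch keeps the full
# height (target_h = 800) and sets target_w = int(800 * 1.0) = 800, then slides
# that 800x800 window to centre it on the masked content while keeping it
# inside the image bounds.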
def smart_crop(input_image, target_aspect_ratio, point_x=None, point_y=None):
    """
    Main function to perform smart cropping
    """
    if input_image is None:
        return None, None

    # Convert to a PIL RGB image
    pil_image = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Generate the mask using SAM, optionally guided by a clicked point
    points = None
    if point_x is not None and point_y is not None and point_x >= 0 and point_y >= 0:
        points = [[point_x, point_y]]
    mask = get_sam_mask(pil_image, points)

    # Calculate the best crop and apply it
    crop_box = find_optimal_crop(pil_image, mask, target_aspect_ratio)
    cropped_img = pil_image.crop(crop_box)

    # Visualize the process: original, mask, and cropped result side by side
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(pil_image)
    ax[0].set_title("Original Image")
    ax[0].axis("off")
    ax[1].imshow(mask, cmap='gray')
    ax[1].set_title("SAM Segmentation Mask")
    ax[1].axis("off")
    ax[2].imshow(cropped_img)
    ax[2].set_title(f"Smart Cropped ({target_aspect_ratio:.2f})")
    ax[2].axis("off")
    plt.tight_layout()

    # Save the visualization to a file for the Gradio filepath output
    vis_path = "visualization.png"
    plt.savefig(vis_path)
    plt.close(fig)

    return cropped_img, vis_path
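
# Minimal usage sketch outside the Gradio UI (the file names are assumptions):
#   cropped, vis = smart_crop(Image.open("example.jpg"), 16 / 9)
#   cropped.save("example_16x9.jpg")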
def aspect_ratio_options(choice):
    """Map aspect ratio choices to actual values"""
    options = {
        "16:9 (Landscape)": 16 / 9,
        "9:16 (Portrait)": 9 / 16,
        "4:3 (Standard)": 4 / 3,
        "3:4 (Portrait)": 3 / 4,
        "1:1 (Square)": 1 / 1,
        "21:9 (Ultrawide)": 21 / 9,
        "2:3 (Portrait)": 2 / 3,
        "3:2 (Landscape)": 3 / 2,
    }
    return options.get(choice, 16 / 9)
def process_image(input_image, aspect_ratio_choice, point_x=None, point_y=None):
    if input_image is None:
        return None, None

    # Map the dropdown choice to a numeric aspect ratio, then crop
    target_aspect_ratio = aspect_ratio_options(aspect_ratio_choice)
    result_img, vis_path = smart_crop(input_image, target_aspect_ratio, point_x, point_y)
    return result_img, vis_path
def create_app():
    with gr.Blocks(title="Smart Image Cropper using SAM") as app:
        gr.Markdown("# Smart Image Cropper using Segment Anything Model (SAM)")
        gr.Markdown("""
        Upload an image and choose your target aspect ratio. The app will use the Segment Anything Model (SAM)
        to identify important content and crop intelligently to preserve it.
        Optionally, you can click on the uploaded image to specify a point of interest.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                aspect_ratio = gr.Dropdown(
                    choices=[
                        "16:9 (Landscape)",
                        "9:16 (Portrait)",
                        "4:3 (Standard)",
                        "3:4 (Portrait)",
                        "1:1 (Square)",
                        "21:9 (Ultrawide)",
                        "2:3 (Portrait)",
                        "3:2 (Landscape)"
                    ],
                    value="16:9 (Landscape)",
                    label="Target Aspect Ratio"
                )
                point_coords = gr.State(value=[None, None])

                def update_coords(img, evt: gr.SelectData):
                    return [evt.index[0], evt.index[1]]

                input_image.select(update_coords, inputs=[input_image], outputs=[point_coords])

                process_btn = gr.Button("Process Image")

            with gr.Column(scale=2):
                output_image = gr.Image(type="pil", label="Cropped Result")
                visualization = gr.Image(type="filepath", label="Process Visualization")

        process_btn.click(
            fn=lambda img, ratio, coords: process_image(img, ratio, coords[0], coords[1]),
            inputs=[input_image, aspect_ratio, point_coords],
            outputs=[output_image, visualization]
        )

        gr.Markdown("""
        ## How It Works
        1. The Segment Anything Model (SAM) analyzes your image to identify the important content
        2. The app finds the optimal crop window that maximizes the preservation of that content
        3. The image is cropped to your desired aspect ratio while keeping the important parts

        ## Tips
        - For better results with specific subjects, click on the important object in the image
        - Try different aspect ratios to see how the model adapts the cropping
        """)

    return app
# Create and launch the app
demo = create_app()

if __name__ == "__main__":
    # For local testing
    demo.launch()
else:
    # For Hugging Face Spaces
    demo.launch()
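
# Deployment note: when this file runs as a Space, a requirements.txt covering
# the imports above is expected alongside it. A minimal sketch (package names
# only; pinning versions is left as a deployment choice, and these names are
# assumptions based on the imports): gradio, numpy, torch, opencv-python,
# Pillow, matplotlib, transformers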