import os
import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
import warnings
warnings.filterwarnings("ignore")
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load SAM model and processor
model_id = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id).to(device)
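# Note (optional): "facebook/sam-vit-large" and "facebook/sam-vit-huge" are also available on the
# Hugging Face Hub and are drop-in replacements for model_id above; they generally segment better
# but need noticeably more memory and longer load times.
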
def get_sam_mask(image, points=None):
    """
    Generate a segmentation mask with SAM, either automatically for the whole image
    or from an optional prompt point.
    """
    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Process image with SAM
    if points is None:
        # Generate automatic masks for the whole image
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Resize the predicted masks back to the original image size
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0]
        # Convert to a binary mask and return the largest one
        masks = masks.numpy()
        if masks.shape[0] > 0:
            # Calculate the area of each candidate mask and keep the largest
            areas = [np.sum(mask) for mask in masks]
            largest_mask_idx = np.argmax(areas)
            return masks[largest_mask_idx].astype(np.uint8) * 255
        else:
            # If no masks were found, return a full-image mask
            return np.ones((image.height, image.width), dtype=np.uint8) * 255
    else:
        # Use the provided points to generate a mask
        inputs = processor(images=image, input_points=[points], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Get the mask for the provided prompt
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0]
        return masks[0].numpy().astype(np.uint8) * 255

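# Example (sketch, kept as comments so nothing extra runs on Spaces; "photo.jpg" is a placeholder):
#   img = Image.open("photo.jpg")
#   mask = get_sam_mask(img, points=[[320, 240]])   # prompt SAM with a single click at (x=320, y=240)
#   Image.fromarray(mask).save("mask.png")          # mask is a uint8 H x W array of 0/255 values
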
def find_optimal_crop(image, mask, target_aspect_ratio):
    """
    Find the optimal crop that preserves important content based on the mask
    """
    # Convert PIL image to numpy array
    image_np = np.array(image)
    h, w = mask.shape
    # Find the bounding box of the important content
    # First, find where the mask is non-zero (important content)
    y_indices, x_indices = np.where(mask > 0)
    if len(y_indices) == 0 or len(x_indices) == 0:
        # Fallback if no mask is found
        content_box = (0, 0, w, h)
    else:
        # Get the bounding box of important content
        min_x, max_x = np.min(x_indices), np.max(x_indices)
        min_y, max_y = np.min(y_indices), np.max(y_indices)
        content_width = max_x - min_x + 1
        content_height = max_y - min_y + 1
        content_box = (min_x, min_y, content_width, content_height)
    # Calculate target dimensions based on the original image
    if target_aspect_ratio > w / h:
        # Target is wider than the original: keep the full width
        target_h = int(w / target_aspect_ratio)
        target_w = w
    else:
        # Target is taller than the original: keep the full height
        target_h = h
        target_w = int(h * target_aspect_ratio)
    # Calculate the center of the important content
    content_center_x = content_box[0] + content_box[2] // 2
    content_center_y = content_box[1] + content_box[3] // 2
    # Try to center the crop on the important content
    x = max(0, min(content_center_x - target_w // 2, w - target_w))
    y = max(0, min(content_center_y - target_h // 2, h - target_h))
    # Check if the important content fits within this crop
    min_x, min_y, content_width, content_height = content_box
    max_x = min_x + content_width
    max_y = min_y + content_height
    # If the content doesn't fit in the crop, adjust the crop
    if target_w >= content_width and target_h >= content_height:
        # If the crop is large enough to include all content, center it
        x = max(0, min(content_center_x - target_w // 2, w - target_w))
        y = max(0, min(content_center_y - target_h // 2, h - target_h))
    else:
        # If the crop isn't large enough for all content, maximize visible content
        # and prioritize centering the crop on the content
        x = max(0, min(min_x, w - target_w))
        y = max(0, min(min_y, h - target_h))
        # If the width still can't fit, center the crop horizontally
        if content_width > target_w:
            x = max(0, min(content_center_x - target_w // 2, w - target_w))
        # If the height still can't fit, center the crop vertically
        if content_height > target_h:
            y = max(0, min(content_center_y - target_h // 2, h - target_h))
    return (x, y, x + target_w, y + target_h)

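# Worked example (illustrative numbers only): for a 1200x800 image whose mask covers x in [500, 899]
# and y in [200, 599] (a 400x400 subject), a 1:1 target gives target_w = target_h = 800. The content
# center is (700, 400), so x = min(700 - 400, 1200 - 800) = 300 and y = 0, and the function returns
# (300, 0, 1100, 800): a square crop centered on the subject.
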
def smart_crop(input_image, target_aspect_ratio, point_x=None, point_y=None):
    """
    Main function to perform smart cropping
    """
    if input_image is None:
        return None, None
    # Open image and convert to RGB
    pil_image = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    # Generate mask using SAM (optionally prompted with a user-clicked point)
    points = None
    if point_x is not None and point_y is not None and point_x >= 0 and point_y >= 0:
        points = [[point_x, point_y]]
    mask = get_sam_mask(pil_image, points)
    # Calculate the best crop
    crop_box = find_optimal_crop(pil_image, mask, target_aspect_ratio)
    # Crop the image
    cropped_img = pil_image.crop(crop_box)
    # Visualize the process
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(pil_image)
    ax[0].set_title("Original Image")
    ax[0].axis("off")
    ax[1].imshow(mask, cmap='gray')
    ax[1].set_title("SAM Segmentation Mask")
    ax[1].axis("off")
    ax[2].imshow(cropped_img)
    ax[2].set_title(f"Smart Cropped ({target_aspect_ratio:.2f})")
    ax[2].axis("off")
    plt.tight_layout()
    # Save the visualization to a temporary file
    vis_path = "visualization.png"
    plt.savefig(vis_path)
    plt.close(fig)
    return cropped_img, vis_path

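# Example (sketch, kept as comments; "photo.jpg" is a placeholder, not a file shipped with this Space):
#   img = Image.open("photo.jpg")
#   cropped, vis = smart_crop(img, 16 / 9)              # automatic SAM mask, no click point
#   cropped, vis = smart_crop(img, 16 / 9, 320, 240)    # or prompt SAM with a click at (320, 240)
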
def aspect_ratio_options(choice):
    """Map aspect ratio choices to actual values"""
    options = {
        "16:9 (Landscape)": 16/9,
        "9:16 (Portrait)": 9/16,
        "4:3 (Standard)": 4/3,
        "3:4 (Portrait)": 3/4,
        "1:1 (Square)": 1/1,
        "21:9 (Ultrawide)": 21/9,
        "2:3 (Portrait)": 2/3,
        "3:2 (Landscape)": 3/2,
    }
    return options.get(choice, 16/9)

def process_image(input_image, aspect_ratio_choice, point_x=None, point_y=None):
    if input_image is None:
        return None, None
    # Get the actual aspect ratio value
    target_aspect_ratio = aspect_ratio_options(aspect_ratio_choice)
    # Process the image
    result_img, vis_path = smart_crop(input_image, target_aspect_ratio, point_x, point_y)
    return result_img, vis_path

def create_app():
    with gr.Blocks(title="Smart Image Cropper using SAM") as app:
        gr.Markdown("# Smart Image Cropper using Segment Anything Model (SAM)")
        gr.Markdown("""
        Upload an image and choose your target aspect ratio. The app will use the Segment Anything Model (SAM)
        to identify important content and crop intelligently to preserve it.
        Optionally, you can click on the uploaded image to specify a point of interest.
        """)
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                aspect_ratio = gr.Dropdown(
                    choices=[
                        "16:9 (Landscape)",
                        "9:16 (Portrait)",
                        "4:3 (Standard)",
                        "3:4 (Portrait)",
                        "1:1 (Square)",
                        "21:9 (Ultrawide)",
                        "2:3 (Portrait)",
                        "3:2 (Landscape)"
                    ],
                    value="16:9 (Landscape)",
                    label="Target Aspect Ratio"
                )
                point_coords = gr.State(value=[None, None])

                def update_coords(img, evt: gr.SelectData):
                    return [evt.index[0], evt.index[1]]

                input_image.select(update_coords, inputs=[input_image], outputs=[point_coords])
                process_btn = gr.Button("Process Image")
            with gr.Column(scale=2):
                output_image = gr.Image(type="pil", label="Cropped Result")
                visualization = gr.Image(type="filepath", label="Process Visualization")
        process_btn.click(
            fn=lambda img, ratio, coords: process_image(img, ratio, coords[0], coords[1]),
            inputs=[input_image, aspect_ratio, point_coords],
            outputs=[output_image, visualization]
        )
        gr.Markdown("""
        ## How It Works
        1. The Segment Anything Model (SAM) analyzes your image to identify the important content
        2. The app finds the optimal crop window that maximizes the preservation of that content
        3. The image is cropped to your desired aspect ratio while keeping the important parts

        ## Tips
        - For better results with specific subjects, click on the important object in the image
        - Try different aspect ratios to see how the model adapts the cropping
        """)
    return app

# Create and launch the app
demo = create_app()
# For local testing
if __name__ == "__main__":
    demo.launch()
else:
    # For Hugging Face Spaces
    demo.launch()