import os
import tempfile
import gradio as gr
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
import warnings
warnings.filterwarnings("ignore")

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load SAM model and processor
model_id = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id).to(device)
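# Note: from_pretrained downloads the SAM checkpoint from the Hugging Face Hub
# and caches it locally on the first run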

def get_sam_mask(image, points=None):
    """
    Generate mask from SAM model based on the entire image
    """
    # Convert to RGB if needed
    if image.mode != "RGB":
        image = image.convert("RGB")
    
    # Process image with SAM
    if points is None:
        # No prompt given: run SAM over the whole image and keep its candidate masks
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Post-process the predicted masks back to the original image resolution
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0]
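        # post_process_masks returns one tensor per image; [0][0] keeps the
        # candidate masks (typically three) for the first prompt, at full resolution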
        
        # Convert to binary mask and return the largest mask
        masks = masks.numpy()
        if masks.shape[0] > 0:
            # Calculate area of each mask and get the largest one
            areas = [np.sum(mask) for mask in masks]
            largest_mask_idx = np.argmax(areas)
            return masks[largest_mask_idx].astype(np.uint8) * 255
        else:
            # If no masks found, return full image mask
            return np.ones((image.height, image.width), dtype=np.uint8) * 255
    else:
        # Use the provided points to generate a mask
        inputs = processor(images=image, input_points=[points], return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Post-process the predicted masks back to the original image resolution
        masks = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0]
        
        # Return the candidate mask with the highest predicted IoU score
        best_idx = outputs.iou_scores.cpu()[0, 0].argmax().item()
        return masks[best_idx].numpy().astype(np.uint8) * 255

def find_optimal_crop(image, mask, target_aspect_ratio):
    """
    Find the optimal crop that preserves important content based on the mask
    """
    # Work in the mask's coordinate system (same size as the original image)
    h, w = mask.shape
    
    # Find the bounding box of the important content
    # First, find where the mask is non-zero (important content)
    y_indices, x_indices = np.where(mask > 0)
    
    if len(y_indices) == 0 or len(x_indices) == 0:
        # Fallback if no mask is found
        content_box = (0, 0, w, h)
    else:
        # Get the bounding box of important content
        min_x, max_x = np.min(x_indices), np.max(x_indices)
        min_y, max_y = np.min(y_indices), np.max(y_indices)
        content_width = max_x - min_x + 1
        content_height = max_y - min_y + 1
        content_box = (min_x, min_y, content_width, content_height)
    
    # Calculate target dimensions based on the original image
    if target_aspect_ratio > w / h:
        # Target is wider than original
        target_h = int(w / target_aspect_ratio)
        target_w = w
    else:
        # Target is taller than original
        target_h = h
        target_w = int(h * target_aspect_ratio)
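    # Worked example: for a 1920x1080 source (w/h ~ 1.78) and a 21:9 target
    # (~2.33 > 1.78), the crop keeps the full width and becomes 1920x822;
    # for a 1:1 target (1.0 < 1.78) it keeps the full height and becomes 1080x1080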
    
    # Calculate the center of the important content
    content_center_x = content_box[0] + content_box[2] // 2
    content_center_y = content_box[1] + content_box[3] // 2
    
    # Unpack the content bounding box to decide how to place the crop window
    min_x, min_y, content_width, content_height = content_box
    max_x = min_x + content_width
    max_y = min_y + content_height
    
    # Position the crop depending on whether all of the content fits inside it
    if target_w >= content_width and target_h >= content_height:
        # If the crop is large enough to include all content, center it
        x = max(0, min(content_center_x - target_w // 2, w - target_w))
        y = max(0, min(content_center_y - target_h // 2, h - target_h))
    else:
        # If crop isn't large enough for all content, maximize visible content
        # and prioritize centering the crop on the content
        x = max(0, min(min_x, w - target_w))
        y = max(0, min(min_y, h - target_h))
        
        # If we still can't fit width, center the crop horizontally
        if content_width > target_w:
            x = max(0, min(content_center_x - target_w // 2, w - target_w))
        
        # If we still can't fit height, center the crop vertically
        if content_height > target_h:
            y = max(0, min(content_center_y - target_h // 2, h - target_h))
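
    # Example: a 1000x800 image whose content spans x in [100, 500] and y in [200, 600],
    # cropped to 1:1 (an 800x800 window): the content center is (300, 400), so centering
    # would put x at -100; it is clamped to 0 and the returned box is (0, 0, 800, 800)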
    
    return (x, y, x + target_w, y + target_h)

def smart_crop(input_image, target_aspect_ratio, point_x=None, point_y=None):
    """
    Main function to perform smart cropping
    """
    if input_image is None:
        return None, None
    
    # Open image and convert to RGB
    pil_image = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    
    # Generate mask using SAM
    points = None
    if point_x is not None and point_y is not None and point_x >= 0 and point_y >= 0:
        points = [[point_x, point_y]]
    
    mask = get_sam_mask(pil_image, points)
    
    # Calculate the best crop
    crop_box = find_optimal_crop(pil_image, mask, target_aspect_ratio)
    
    # Crop the image
    cropped_img = pil_image.crop(crop_box)
    
    # Visualize the process
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(pil_image)
    ax[0].set_title("Original Image")
    ax[0].axis("off")
    
    ax[1].imshow(mask, cmap='gray')
    ax[1].set_title("SAM Segmentation Mask")
    ax[1].axis("off")
    
    ax[2].imshow(cropped_img)
    ax[2].set_title(f"Smart Cropped ({target_aspect_ratio:.2f})")
    ax[2].axis("off")
    
    plt.tight_layout()
    
    # Save the visualization to a unique temporary file so that concurrent
    # requests do not overwrite each other's output
    fd, vis_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    plt.savefig(vis_path)
    plt.close(fig)
    
    return cropped_img, vis_path

def aspect_ratio_options(choice):
    """Map aspect ratio choices to actual values"""
    options = {
        "16:9 (Landscape)": 16/9,
        "9:16 (Portrait)": 9/16,
        "4:3 (Standard)": 4/3,
        "3:4 (Portrait)": 3/4,
        "1:1 (Square)": 1/1,
        "21:9 (Ultrawide)": 21/9,
        "2:3 (Portrait)": 2/3,
        "3:2 (Landscape)": 3/2,
    }
    return options.get(choice, 16/9)

def process_image(input_image, aspect_ratio_choice, point_x=None, point_y=None):
    if input_image is None:
        return None, None
    
    # Get the actual aspect ratio value
    target_aspect_ratio = aspect_ratio_options(aspect_ratio_choice)
    
    # Process the image
    result_img, vis_path = smart_crop(input_image, target_aspect_ratio, point_x, point_y)
    
    return result_img, vis_path

def create_app():
    with gr.Blocks(title="Smart Image Cropper using SAM") as app:
        gr.Markdown("# Smart Image Cropper using Segment Anything Model (SAM)")
        gr.Markdown("""
        Upload an image and choose your target aspect ratio. The app will use the Segment Anything Model (SAM) 
        to identify important content and crop intelligently to preserve it.
        
        Optionally, you can click on the uploaded image to specify a point of interest.
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                aspect_ratio = gr.Dropdown(
                    choices=[
                        "16:9 (Landscape)",
                        "9:16 (Portrait)",
                        "4:3 (Standard)",
                        "3:4 (Portrait)",
                        "1:1 (Square)",
                        "21:9 (Ultrawide)",
                        "2:3 (Portrait)",
                        "3:2 (Landscape)"
                    ],
                    value="16:9 (Landscape)",
                    label="Target Aspect Ratio"
                )
                point_coords = gr.State(value=[None, None])
                
                def update_coords(img, evt: gr.SelectData):
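                    # evt.index holds the (x, y) pixel coordinates of the click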
                    return [evt.index[0], evt.index[1]]
                
                input_image.select(update_coords, inputs=[input_image], outputs=[point_coords])
                
                process_btn = gr.Button("Process Image")
            
            with gr.Column(scale=2):
                output_image = gr.Image(type="pil", label="Cropped Result")
                visualization = gr.Image(type="filepath", label="Process Visualization")
        
        process_btn.click(
            fn=lambda img, ratio, coords: process_image(img, ratio, coords[0], coords[1]),
            inputs=[input_image, aspect_ratio, point_coords],
            outputs=[output_image, visualization]
        )
        
        gr.Markdown("""
        ## How It Works
        
        1. The Segment Anything Model (SAM) analyzes your image to identify the important content
        2. The app finds the optimal crop window that maximizes the preservation of that content
        3. The image is cropped to your desired aspect ratio while keeping the important parts
        
        ## Tips
        
        - For better results with specific subjects, click on the important object in the image
        - Try different aspect ratios to see how the model adapts the cropping
        """)
    
    return app

# Create and launch the app
demo = create_app()

# Launch the interface (the same call covers local testing and Hugging Face Spaces)
demo.launch()
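
# Usage sketch without the UI (hypothetical file names, for illustration only):
#     from PIL import Image
#     img = Image.open("photo.jpg")
#     cropped, vis_path = smart_crop(img, 16 / 9)
#     cropped.save("photo_16x9.jpg")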