mostlycached committed
Commit f9efdb7 · verified · 1 Parent(s): b84e980

Update app.py

Files changed (1)
  1. app.py +80 -78
app.py CHANGED
@@ -3,53 +3,71 @@ import numpy as np
 import cv2
 from PIL import Image
 import torch
-from transformers import AutoImageProcessor, AutoModelForObjectDetection
+from transformers import SamModel, SamProcessor
+import os
 
 # Set up device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-# Load object detection model for identifying important content
-print("Loading object detection model...")
-processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
-model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
+# Load SAM model for segmentation
+print("Loading SAM model...")
+sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
+sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
 
-def detect_objects(image):
-    """Detect important objects in the image using DETR"""
-    # Convert to PIL if needed
-    if not isinstance(image, Image.Image):
-        image_pil = Image.fromarray(image)
-    else:
+def get_sam_masks(image):
+    """Get segmentation masks using SAM model"""
+    # Convert to numpy if needed
+    if isinstance(image, Image.Image):
         image_pil = image
-    image = np.array(image_pil)
-
-    # Get image dimensions
-    h, w = image.shape[:2]
-
-    # Process image for object detection
-    inputs = processor(images=image_pil, return_tensors="pt").to(device)
+        image_np = np.array(image)
+    else:
+        image_np = image
+        image_pil = Image.fromarray(image_np)
+
+    h, w = image_np.shape[:2]
+
+    # Create a grid of points to sample the image
+    x_points = np.linspace(w//4, 3*w//4, 5, dtype=int)
+    y_points = np.linspace(h//4, 3*h//4, 5, dtype=int)
+    grid_points = []
+    for y in y_points:
+        for x in x_points:
+            grid_points.append([x, y])
+    points = [grid_points]
+
+    # Process image through SAM
+    inputs = sam_processor(
+        images=image_pil,
+        input_points=points,
+        return_tensors="pt"
+    ).to(device)
+
+    # Generate masks
     with torch.no_grad():
-        outputs = model(**inputs)
-
-    # Convert outputs to usable format
-    target_sizes = torch.tensor([image_pil.size[::-1]]).to(device)
-    results = processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[0]
+        outputs = sam_model(**inputs)
+    masks = sam_processor.image_processor.post_process_masks(
+        outputs.pred_masks.cpu(),
+        inputs["original_sizes"].cpu(),
+        inputs["reshaped_input_sizes"].cpu()
+    )
+
+    # Combine all masks to create importance map
+    importance_map = np.zeros((h, w), dtype=np.float32)
+    individual_masks = []
 
-    # Store detected objects
-    detected_boxes = []
+    for i in range(len(masks[0])):
+        mask = masks[0][i].numpy().astype(np.float32)
+        individual_masks.append(mask)
+        importance_map += mask
 
-    # For each detected object
-    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-        box = box.cpu().numpy().astype(int)
-        detected_boxes.append({
-            'box': box,
-            'score': score.item(),
-            'label': model.config.id2label[label.item()]
-        })
+    # Normalize to 0-1
+    if importance_map.max() > 0:
+        importance_map = importance_map / importance_map.max()
 
-    return detected_boxes
+    return importance_map, individual_masks
 
-def find_optimal_crop(image, target_ratio, objects):
+def find_optimal_crop(image, target_ratio, importance_map):
     """Find the optimal crop area that preserves important content while matching target ratio"""
     # Get image dimensions
     if not isinstance(image, np.ndarray):
@@ -59,34 +77,6 @@ def find_optimal_crop(image, target_ratio, objects):
     current_ratio = w / h
     target_ratio_value = eval(target_ratio.replace(':', '/'))
 
-    # If no objects detected, use center crop
-    if not objects:
-        if current_ratio > target_ratio_value:
-            # Need to crop width
-            new_width = int(h * target_ratio_value)
-            left = (w - new_width) // 2
-            right = left + new_width
-            return (left, 0, right, h)
-        else:
-            # Need to crop height
-            new_height = int(w / target_ratio_value)
-            top = (h - new_height) // 2
-            bottom = top + new_height
-            return (0, top, w, bottom)
-
-    # Create a combined importance map from all detected objects
-    importance_map = np.zeros((h, w), dtype=np.float32)
-
-    # Add all objects to the importance map
-    for obj in objects:
-        x1, y1, x2, y2 = obj['box']
-        # Ensure box is within image boundaries
-        x1, y1 = max(0, x1), max(0, y1)
-        x2, y2 = min(w-1, x2), min(h-1, y2)
-
-        # Add object to importance map with its confidence score
-        importance_map[y1:y2, x1:x2] = max(importance_map[y1:y2, x1:x2], obj['score'])
-
     # If current ratio is wider than target, we need to crop width
     if current_ratio > target_ratio_value:
         new_width = int(h * target_ratio_value)
@@ -144,16 +134,16 @@ def apply_crop(image, crop_box):
 
 def adjust_aspect_ratio(image, target_ratio):
     """Main function to adjust aspect ratio through intelligent cropping"""
-    # Detect objects in the image
-    objects = detect_objects(image)
+    # Get segmentation masks and importance map
+    importance_map, _ = get_sam_masks(image)
 
     # Find optimal crop box
-    crop_box = find_optimal_crop(image, target_ratio, objects)
+    crop_box = find_optimal_crop(image, target_ratio, importance_map)
 
     # Apply the crop
     result = apply_crop(image, crop_box)
 
-    return result
+    return result, importance_map
 
 def process_image(input_image, target_ratio="16:9"):
     """Process function for Gradio interface"""
@@ -165,7 +155,7 @@ def process_image(input_image, target_ratio="16:9"):
         image = input_image
 
         # Adjust aspect ratio
-        result = adjust_aspect_ratio(image, target_ratio)
+        result, importance_map = adjust_aspect_ratio(image, target_ratio)
 
         # Convert result to appropriate format
         if isinstance(result, np.ndarray):
@@ -173,15 +163,26 @@ def process_image(input_image, target_ratio="16:9"):
         else:
             result_pil = result
 
-        return result_pil
+        # Visualize importance map for debugging
+        if isinstance(importance_map, np.ndarray):
+            # Convert to heatmap
+            heatmap = (importance_map * 255).astype(np.uint8)
+            heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+
+            # Convert to PIL
+            heatmap_pil = Image.fromarray(heatmap)
+
+            return [result_pil, heatmap_pil]
+
+        return [result_pil, None]
 
     except Exception as e:
         print(f"Error processing image: {e}")
-        return None
+        return [None, None]
 
 # Create the Gradio interface
-with gr.Blocks(title="Smart Crop Aspect Ratio Adjuster") as demo:
-    gr.Markdown("# Smart Crop Aspect Ratio Adjuster")
+with gr.Blocks(title="SAM-Based Smart Crop Aspect Ratio Adjuster") as demo:
+    gr.Markdown("# SAM-Based Smart Crop Aspect Ratio Adjuster")
     gr.Markdown("Upload an image, choose your target aspect ratio, and the AI will intelligently crop it to preserve important content.")
 
     with gr.Row():
@@ -199,23 +200,24 @@ with gr.Blocks(title="Smart Crop Aspect Ratio Adjuster") as demo:
 
         with gr.Column():
             output_image = gr.Image(label="Processed Image")
+            importance_map_vis = gr.Image(label="Importance Map (Debug View)")
 
     submit_btn.click(
         process_image,
        inputs=[input_image, aspect_ratio],
-        outputs=output_image
+        outputs=[output_image, importance_map_vis]
    )
 
    gr.Markdown("""
    ## How it works
-    1. **Object Detection**: The app uses a DETR (DEtection TRansformer) model to identify important objects in your image
-    2. **Importance Mapping**: It creates an importance map based on detected objects
-    3. **Smart Cropping**: The algorithm finds the optimal crop window that preserves the most important content
+    1. **Segmentation**: Uses Meta's Segment Anything Model (SAM) to identify important regions in your image
+    2. **Importance Mapping**: Creates a heatmap of important areas based on segmentation masks
+    3. **Smart Cropping**: Finds the optimal crop window that preserves the most important content
 
    ## Tips
-    - For best results, ensure important subjects are visible and not too close to the edges
+    - For best results, ensure important subjects are clearly visible in the image
+    - The importance map shows what the AI considers important (red/yellow = important, blue = less important)
    - Try different aspect ratios to see what works best with your image
-    - The model works best with clear, well-lit images with distinct objects
    """)
 
 # Launch the app
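
The hunks above elide the body of find_optimal_crop, which both versions share. As a rough sketch of how an importance map can drive the window search (the cumulative-sum scoring and the name sketch_find_crop are illustrative assumptions, not the code app.py actually ships):

import numpy as np

def sketch_find_crop(importance_map, target_ratio_value):
    # Sketch only: slide a window of the target aspect ratio across the
    # map and keep the offset enclosing the most total importance.
    h, w = importance_map.shape
    if w / h > target_ratio_value:
        # Image too wide: window spans the full height, slides horizontally.
        new_w = int(h * target_ratio_value)
        cum = np.concatenate(([0.0], np.cumsum(importance_map.sum(axis=0))))
        scores = cum[new_w:] - cum[:-new_w]  # window total per left offset
        left = int(np.argmax(scores))
        return (left, 0, left + new_w, h)
    # Image too tall: window spans the full width, slides vertically.
    new_h = int(w / target_ratio_value)
    cum = np.concatenate(([0.0], np.cumsum(importance_map.sum(axis=1))))
    scores = cum[new_h:] - cum[:-new_h]  # window total per top offset
    top = int(np.argmax(scores))
    return (0, top, w, top + new_h)

On an all-zero map this degenerates to the leftmost/topmost window, which is why the removed DETR version special-cased a center crop when nothing was detected.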
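
Both versions also keep target_ratio_value = eval(target_ratio.replace(':', '/')) as a context line. The dropdown constrains the input in practice, but an eval-free parse is a small change; parse_ratio below is an illustrative name, not part of the commit:

def parse_ratio(target_ratio):
    # '16:9' -> 16/9 without executing user input as Python.
    num, den = target_ratio.split(':')
    return float(num) / float(den)

assert abs(parse_ratio('16:9') - 16 / 9) < 1e-9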
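
One subtlety in the new debug view: cv2.applyColorMap returns a BGR image while Image.fromarray assumes RGB, so the displayed heatmap has red and blue swapped relative to the legend in the Tips. If exact colors matter, a cvtColor call between those two lines would fix it (a self-contained sketch with a stand-in map, not part of the commit):

import cv2
import numpy as np
from PIL import Image

importance_map = np.random.rand(64, 64).astype(np.float32)  # stand-in for the real map
heatmap = cv2.applyColorMap((importance_map * 255).astype(np.uint8), cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV returns BGR; PIL expects RGB
heatmap_pil = Image.fromarray(heatmap)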