"""Gradio demo that crops an image around the area where DETR detects the most objects."""
import gradio as gr
import numpy as np
from PIL import Image, ImageOps
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch

# Side length, in pixels, of the square grid cells used to group detections.
GRID_SIZE = 100
# Extra margin, in pixels, added on every side of the requested crop window.
BLEED = 10
# Minimum confidence score a detection needs to be counted.
DETECTION_THRESHOLD = 0.9

feature_extractor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101")
dmodel = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101")

i1 = gr.inputs.Image(type="pil", label="Input image")
i2 = gr.inputs.Number(default=400, label="Custom Width")
i3 = gr.inputs.Number(default=400, label="Custom Height")
o1 = gr.outputs.Image(type="pil", label="Cropped part")


def extract_image(image, custom_width, custom_height):
    """Crop ``image`` to the grid cell that contains the most detected objects.

    Detections are grouped by the ``GRID_SIZE``-pixel cell that contains the
    top-left corner of their bounding box. The crop window starts at the
    origin of the most populated cell, spans ``custom_width`` x
    ``custom_height`` pixels, is widened by ``BLEED`` pixels on every side,
    and is finally clamped to the image bounds.

    Args:
        image: PIL image to analyse and crop.
        custom_width: Requested crop width in pixels (int or float).
        custom_height: Requested crop height in pixels (int or float).

    Returns:
        The cropped PIL image, or ``image`` unchanged when no detection
        reaches ``DETECTION_THRESHOLD``.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = dmodel(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = feature_extractor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=DETECTION_THRESHOLD
    )[0]

    # Count detections per grid cell, keyed by the cell's top-left corner.
    object_counts = {}
    for box in results["boxes"]:
        area_key = (
            int(box[0] / GRID_SIZE) * GRID_SIZE,
            int(box[1] / GRID_SIZE) * GRID_SIZE,
        )
        object_counts[area_key] = object_counts.get(area_key, 0) + 1

    # Nothing detected: max() on an empty dict would raise, so there is no
    # meaningful area to crop to. Hand the image back untouched.
    if not object_counts:
        return image

    # Ties resolve to the first cell seen, as dicts keep insertion order.
    area_x, area_y = max(object_counts, key=object_counts.get)

    # Crop window anchored at the cell origin, widened by the bleed and
    # clamped so it never leaves the image.
    xmin = max(0, int(area_x - BLEED))
    ymin = max(0, int(area_y - BLEED))
    xmax = min(image.width, int(area_x + custom_width + BLEED))
    ymax = min(image.height, int(area_y + custom_height + BLEED))

    return image.crop((xmin, ymin, xmax, ymax))


title = "Auto Crop"
description = """
Crop an image with the area containing the most detected objects.
"""
examples = [['ex3.jpg', 800, 400], ['cat.png', 400, 400]]

gr.Interface(
    fn=extract_image,
    inputs=[i1, i2, i3],
    outputs=[o1],
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
).launch()