detr101crop / app.py
kinsung's picture
img
9daaec3
raw
history blame
2.6 kB
import gradio as gr
import numpy as np
from PIL import Image, ImageOps
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
# Hugging Face checkpoint shared by the image processor and the detection model.
_DETR_CHECKPOINT = "facebook/detr-resnet-101"

# Pre/post-processing for DETR (used for model inputs and box decoding) and the
# detection model itself; both are loaded once at import time.
feature_extractor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
dmodel = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)

# Gradio components for the interface: one PIL image plus two numeric crop sizes in,
# one PIL image out.
# NOTE(review): gr.inputs / gr.outputs is the legacy component namespace (removed in
# Gradio 4) — confirm the gradio version this app is pinned to.
i1 = gr.inputs.Image(label="Input image", type="pil")
i2 = gr.inputs.Number(label="Custom Width (optional)", default=400)
i3 = gr.inputs.Number(label="Custom Height (optional)", default=400)
o1 = gr.outputs.Image(label="Cropped part", type="pil")
def extract_image(image, custom_width, custom_height):
    """Crop *image* around the spot where DETR finds the most objects.

    Detections scoring above 0.9 are bucketed into a 100x100-pixel grid by the
    top-left corner of each box. The crop window (custom_width x custom_height)
    is centred on the mean centre of the boxes in the most populated cell,
    shifted so it stays inside the image whenever it fits, and then widened by
    a 10-pixel bleed on every side (clipped to the image bounds).

    Args:
        image: PIL image from the Gradio input, or None when nothing was uploaded.
        custom_width: Requested crop width in pixels. May be a float or None
            (e.g. a cleared Gradio Number field); missing or non-positive values
            fall back to the full image width.
        custom_height: Requested crop height, handled like custom_width.

    Returns:
        The cropped PIL image, or None when no image was given. If nothing is
        detected, the window is centred on the image instead of raising.
    """
    if image is None:
        return None

    crop_w = _crop_dimension(custom_width, image.width)
    crop_h = _crop_dimension(custom_height, image.height)

    boxes = _detect_boxes(image)
    center_x, center_y = _densest_center(boxes, image.width, image.height)

    # Centre the window on the detected objects, and shift it (rather than
    # truncate it) near the edges so the requested size is preserved.
    left = _place_window(center_x, crop_w, image.width)
    top = _place_window(center_y, crop_h, image.height)

    bleed = 10  # extra margin on each side, clipped to the image bounds
    xmin = max(0, left - bleed)
    ymin = max(0, top - bleed)
    xmax = min(image.width, left + crop_w + bleed)
    ymax = min(image.height, top + crop_h + bleed)
    return image.crop((xmin, ymin, xmax, ymax))


def _crop_dimension(value, fallback):
    """Turn a user-supplied size into a positive int, using *fallback* if missing/non-positive."""
    if value is None:
        return fallback
    size = int(value)
    return size if size > 0 else fallback


def _detect_boxes(image, threshold=0.9):
    """Run DETR on *image* and return (xmin, ymin, xmax, ymax) boxes scoring above *threshold*."""
    # Feed the model a 3-channel copy so RGBA/L/P inputs match the processor's
    # RGB normalisation; the caller still crops the original image.
    rgb = image.convert("RGB")
    inputs = feature_extractor(images=rgb, return_tensors="pt")
    with torch.no_grad():  # inference only: don't build an autograd graph
        outputs = dmodel(**inputs)
    # Post-processing expects (height, width); PIL's size is (width, height).
    target_sizes = torch.tensor([rgb.size[::-1]])
    results = feature_extractor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return [tuple(float(v) for v in box) for box in results["boxes"].tolist()]


def _densest_center(boxes, image_width, image_height, cell_size=100):
    """Return the mean box centre of the most populated grid cell (image centre if no boxes)."""
    if not boxes:
        return image_width / 2, image_height / 2
    cells = {}
    for box in boxes:
        # Bucket by the box's top-left corner; int() truncation as in the original grouping.
        key = (int(box[0] / cell_size), int(box[1] / cell_size))
        cells.setdefault(key, []).append(box)
    # max() keeps the first cell on ties, i.e. the earliest-inserted one.
    members = max(cells.values(), key=len)
    center_x = sum((b[0] + b[2]) / 2 for b in members) / len(members)
    center_y = sum((b[1] + b[3]) / 2 for b in members) / len(members)
    return center_x, center_y


def _place_window(center, size, limit):
    """Return the start of a *size*-long window centred on *center*, kept within [0, limit] when it fits."""
    start = int(round(center - size / 2))
    start = min(start, limit - size)  # pull back from the far edge
    return max(0, start)  # and never start before the near edge
# Text shown in the Gradio interface header.
title = "Social Media Crop"
description = (
    "<p style='color:white'>"
    "Crop an image with the area containing the most detected objects "
    "while maintaining custom dimensions and adding a 10-pixel bleed. "
    "The area is centralized within the custom dimensions."
    "</p>"
)

# Each example row is [image file, width, height], matching inputs i1, i2, i3.
# Presumably these files ship next to app.py — verify they exist in the repo.
examples = [
    ["ex3.jpg", 800, 400],
    ["cat.png", 400, 400],
]

# Build the interface and start the server.
# NOTE(review): enable_queue is a legacy Interface argument (replaced by .queue() in
# newer Gradio) — confirm against the pinned gradio version.
demo = gr.Interface(
    fn=extract_image,
    inputs=[i1, i2, i3],
    outputs=[o1],
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
)
demo.launch()