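"""Gradio demo: detect the largest object in an image with DETR, mask it,
and replace it with Stable Diffusion in-painting guided by a text prompt."""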
import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection

# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", revision="no_timm")
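
# DETR returns class logits and normalized bounding boxes; the processor's
# post_process_object_detection converts them to absolute (x1, y1, x2, y2) boxes.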

def biggest_obj(res):
    """Return the index, box coordinates and label of the largest detected object."""
    # assumes at least one detection above the confidence threshold
    max_area = 0
    ind, coords = None, None
    for i, bb in enumerate(res["boxes"]):
        x1, y1, x2, y2 = map(int, bb.tolist())
        area = abs(x2 - x1) * abs(y2 - y1)
        if area > max_area:
            max_area = area
            ind = i
            coords = [x1, y1, x2, y2]
    cl = model.config.id2label[res["labels"][ind].item()]
    return ind, coords, cl

def create_mask(im_shape: tuple, mask_zone: list):
    """Create a binary mask: white over the rectangle to in-paint, black elsewhere."""
    mask = Image.new("L", im_shape, 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle(mask_zone, fill=255)
    return mask
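
# e.g. create_mask((512, 512), [100, 120, 300, 340]) yields a black 512x512 mask
# with a white rectangle over x in [100, 300], y in [120, 340] (the region to repaint)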

from diffusers import StableDiffusionInpaintPipeline

# use fp16 weights on GPU; fall back to fp32 on CPU, where half precision is not supported
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=dtype,
).to(device)

def predict(image, prompt):
    # the in-painting pipeline expects 512x512 RGB inputs
    image = image.convert("RGB").resize((512, 512))

    # run DETR object detection
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # convert outputs (bounding boxes and class logits) to COCO API format,
    # keeping only detections with score > 0.9
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

    # find the biggest bounding box on the image
    ind, coords, cl = biggest_obj(results)

    # mask out the detected region for in-painting
    mask_image = create_mask(image.size, coords)

    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=5,
        generator=torch.Generator(device=device).manual_seed(0),
        num_images_per_prompt=1,
    ).images

    # draw the detected rectangle (left-top, right-bottom) on the original image
    draw_on_image = ImageDraw.Draw(image)
    draw_on_image.rectangle(coords, outline="red", width=2)

    return images[0], image

examples = [
    ["cats.png", "cat is smiling"],
    ["dog.jpg", "dog with big eyes"],
    ["dog1.jpg", "dog with big bone"],
    ["beaver.jpg", "big strong beaver"],
]

gr.Interface(
    predict,
    title="Stable Diffusion In-Painting",
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=[
        gr.Image(),
        gr.Image(),
    ],
    examples=examples,
).launch(debug=True, share=True)