# CLiPcrop / app.py
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoFeatureExtractor, AutoModelForObjectDetection
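
# DETR proposes candidate object regions; CLIP ranks each cropped region against the text query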
feature_extractor = AutoFeatureExtractor.from_pretrained("nielsr/detr-resnet-50")
dmodel = AutoModelForObjectDetection.from_pretrained("nielsr/detr-resnet-50")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
i1 = gr.inputs.Image(type="pil", label="Input image")
i2 = gr.inputs.Textbox(label="Input text")
i3 = gr.inputs.Number(default=0.96, label="Detection confidence threshold (0-1)")
o1 = gr.outputs.Image(type="pil", label="Cropped part")
o2 = gr.outputs.Textbox(label="Similarity score")
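

# Detect objects with DETR, crop the regions above the confidence threshold, then return the crop
# with the highest CLIP similarity to the input text.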
def extract_image(image, text, prob, num=1):
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = dmodel(**inputs)
    # DETR predicts bounding boxes and COCO class logits; drop the last ("no object") class
    # and keep only detections whose top class probability exceeds the threshold
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob
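    # Rescale boxes to the original image size (PIL's image.size is (width, height))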
    outs = feature_extractor.post_process(outputs, torch.tensor(image.size[::-1]).unsqueeze(0))
    bboxes_scaled = outs[0]['boxes'][keep].detach().numpy()
    labels = outs[0]['labels'][keep].detach().numpy()
    scores = outs[0]['scores'][keep].detach().numpy()
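
    # Crop each kept bounding box out of the original image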
    images_list = []
    im_arr = np.array(image)
    for i, box in enumerate(bboxes_scaled):
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        roi = im_arr[ymin:ymax, xmin:xmax]
        images_list.append(Image.fromarray(roi))
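
    # Score every crop against the text query with CLIP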
    inpu = processor(text=[text], images=images_list, return_tensors="pt", padding=True)
    output = model(**inpu)
    logits_per_text = output.logits_per_text  # shape (1, num_crops)
    probs = logits_per_text.softmax(-1)
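
    # Keep the `num` highest-scoring crops and return the best one with its score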
    l_idx = np.argsort(probs[-1].detach().numpy())[::-1][0:num]
    final_ims = []
    for i, crop in enumerate(images_list):
        if i in l_idx:
            final_ims.append({'image': crop, 'score': probs[-1].detach().numpy()[i]})
    fi = sorted(final_ims, key=lambda item: item.get("score"), reverse=True)
    return fi[0]['image'], fi[0]['score']
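

# Minimal local sanity check (sketch), assuming one of the example images listed below
# (e.g. ex3.jpg) is available next to this script:
#   crop, score = extract_image(Image.open("ex3.jpg"), "black bag", 0.96)
#   crop.save("cropped.jpg")
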
title = "ClipnCrop"
description = "Extract the part of an image that best matches a text query, using OpenAI's CLIP and Facebook's DETR as implemented in Hugging Face Transformers"
examples=[['ex3.jpg', 'black bag', 0.96],['ex2.jpg', 'man in red dress', 0.85]]
article = "<p style='text-align: center'><a href='https://github.com/Vishnunkumar/clipcrop' target='_blank'>clipcrop</a></p>"
gr.Interface(fn=extract_image, inputs=[i1, i2, i3], outputs=[o1, o2], title=title, description=description, article=article, examples=examples, enable_queue=True).launch()