import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
# Device on which the CLIP model is placed; the demo runs on CPU.
device = "cpu"
# Load the CLIP model named "RN50x64" once, at import time.  `preprocess` is the
# companion transform that is applied to a PIL image before it is passed to `model`.
model, preprocess = clip.load("RN50x64", device=device)
def _white_to_transparent(image, threshold):
    """Return a copy of RGBA ``image`` whose near-white pixels are fully transparent.

    A pixel counts as near-white when its R, G and B values are all strictly
    greater than ``threshold``.
    """
    pixels = np.array(image)
    near_white = (pixels[..., :3] > threshold).all(axis=-1)
    pixels[near_white, 3] = 0
    # The mode is inferred from the (H, W, 4) uint8 array, so no explicit
    # ``mode=`` argument is needed.
    return Image.fromarray(pixels)


def img_process(img1, img2, location_width, location_height, size_width, size_height,
                background_size=(600, 400), white_threshold=240):
    """Composite an object image onto a background image.

    The background is converted to RGBA and resized to ``background_size``.
    The object is resized to ``(size_width, size_height)`` and pasted with its
    top-left corner at ``(location_width, location_height)``.

    Object images that carry transparency (modes ``RGBA``, ``LA``, ``PA``, or
    palette images with a transparency entry) are pasted using their own alpha
    channel.  Every other image is converted to RGB and its near-white pixels
    (R, G and B all above ``white_threshold``) are made transparent, so an
    object photographed on a white background is cut out before pasting.

    Args:
        img1: Path or file object of the object image.
        img2: Path or file object of the background image.
        location_width: X coordinate of the paste position, in pixels.
        location_height: Y coordinate of the paste position, in pixels.
        size_width: Width the object is resized to, in pixels.
        size_height: Height the object is resized to, in pixels.
        background_size: ``(width, height)`` the background is resized to.
        white_threshold: Channel value above which a pixel is treated as white.

    Returns:
        The composited background as an RGBA ``PIL.Image.Image``.

    Raises:
        ValueError: If the requested object size is not strictly positive.
    """
    # UI widgets may deliver floats; PIL needs integer pixel coordinates.
    position = (int(location_width), int(location_height))
    target_size = (int(size_width), int(size_height))
    if target_size[0] <= 0 or target_size[1] <= 0:
        raise ValueError(f"object size must be positive, got {target_size}")

    # Image.open is lazy and keeps the file open; convert()/resize() produce
    # in-memory copies, so the handles can be released right afterwards.
    with Image.open(img2) as background_file:
        canvas = background_file.convert('RGBA').resize(background_size)

    with Image.open(img1) as object_file:
        has_alpha = (
            object_file.mode in ('RGBA', 'LA', 'PA')
            or (object_file.mode == 'P' and 'transparency' in object_file.info)
        )
        if has_alpha:
            sprite = object_file.convert('RGBA').resize(target_size)
        else:
            # Resize in RGB first, then add the alpha channel that the
            # white-background cut-out writes into.
            sprite = object_file.convert('RGB').resize(target_size).convert('RGBA')
            sprite = _white_to_transparent(sprite, white_threshold)

    # Use the object's alpha channel as the paste mask so transparent pixels
    # leave the background untouched.
    canvas.paste(sprite, position, mask=sprite.split()[-1])
    return canvas
def itm(obj, back, location_width, location_height, size_width, size_height,
        is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr):
    """Score a composited image against a positive and a negative text prompt.

    The object image is composited onto the background with ``img_process``.
    The result is compared by the CLIP model against two prompts of the form
    ``"a photo of {attribute} {object}"``:

    * the positive prompt always uses ``pos_attr`` and ``pos_obj``;
    * the negative prompt swaps in ``neg_obj`` when ``is_obj`` is true and
      ``neg_attr`` when ``is_attr`` is true.  If neither flag is set, the two
      prompts are identical.

    Args:
        obj: Path of the object image.
        back: Path of the background image.
        location_width, location_height: Paste position of the object.
        size_width, size_height: Size the object is resized to.
        is_obj: Whether the negative prompt replaces the object word.
        pos_obj, neg_obj: Object word for the positive / negative prompt.
        is_attr: Whether the negative prompt replaces the attribute word.
        pos_attr, neg_attr: Attribute word for the positive / negative prompt.

    Returns:
        A 5-tuple ``(positive_prompt, positive_score, negative_prompt,
        negative_score, composited_image)``.  The scores are the softmax of
        the image-to-text logits over the two prompts.
    """
    composite = img_process(obj, back, location_width, location_height,
                            size_width, size_height)

    obj_prompt = neg_obj if is_obj else pos_obj
    attr_prompt = neg_attr if is_attr else pos_attr
    positive_prompt = f"a photo of {pos_attr} {pos_obj}"
    negative_prompt = f"a photo of {attr_prompt} {obj_prompt}"

    # Keep the inputs on the same device the model was loaded on, so changing
    # the module-level `device` does not cause a device-mismatch error.
    image_batch = preprocess(composite).unsqueeze(0).to(device)
    text_tokens = clip.tokenize([positive_prompt, negative_prompt]).to(device)

    with torch.no_grad():
        # The model also returns text-to-image logits; only the image-to-text
        # logits are needed here.
        logits_per_image, _ = model(image_batch, text_tokens)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    print("Label probs:", probs)

    return positive_prompt, probs[0][0], negative_prompt, probs[0][1], composite
# Help text shown under the title.  Kept at module level so the markdown lines
# stay at column 0 regardless of how deeply the layout code is nested.
_TIPS_MARKDOWN = """
Tips:
- In this demo, you can change the object and attribute of object in the text prompt, and you can also change the size and location of the object.
- Please upload an object image with white background.
- The model we used in the demo is CLIP.
"""

# Build the demo page: inputs on top, a submit button, then the outputs and a
# set of cached examples.  Event handlers must be registered inside the
# `gr.Blocks` context, so `btn.click` stays within the `with` block.
with gr.Blocks() as demo:
    # A plain "..." literal cannot span several lines; the line breaks are
    # written as escapes so the title text itself is unchanged.
    gr.Markdown("\nVL-Checklist Demo\n")
    gr.Markdown(_TIPS_MARKDOWN)

    with gr.Row():
        with gr.Column():
            img_obj = gr.Image(value='sample/apple.png', type="filepath",
                               label='object_img(Plz input an object with white background)')
            loc_w = gr.Slider(maximum=500, label='location_width', step=1)
            loc_h = gr.Slider(maximum=300, label='location_height', step=1)
            s_w = gr.Number(value=200, precision=0, label='size_width')
            s_h = gr.Number(value=200, precision=0, label='size_height')
            gr.Markdown("Click **Submit** to get the output!")
        with gr.Column():
            img_back = gr.Image(value='sample/back.jpg', type="filepath", label='background_img')
            is_obj = gr.Checkbox(value=True, label='Does negative prompt change the object?')
            pos_obj = gr.Textbox(value='apple', label='positive object')
            neg_obj = gr.Textbox(value='dog', label='negative object')
            is_attr = gr.Checkbox(value=False, label='Does negative prompt change the attribute?')
            pos_attr = gr.Textbox(value='red', label='positive attribute')
            neg_attr = gr.Textbox(value='green', label='negative attribute')

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        with gr.Column():
            img_output = gr.Image(type="pil", label='output_img')
        with gr.Column():
            pos_prom = gr.Textbox(label='Positive prompt')
            pos_s = gr.Textbox(label='Positive score')
            neg_prom = gr.Textbox(label='Negative prompt')
            neg_s = gr.Textbox(label='Negative score')

    # The component order must match the parameter order and return order of
    # `itm`; both the examples and the button share these lists.
    itm_inputs = [img_obj, img_back, loc_w, loc_h, s_w, s_h,
                  is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr]
    itm_outputs = [pos_prom, pos_s, neg_prom, neg_s, img_output]

    with gr.Row():
        gr.Examples(
            examples=[
                ['sample/apple.png', 'sample/back.jpg', 50, 50, 200, 200,
                 True, 'apple', 'dog', False, 'red', 'green'],
                ['sample/banana.jpg', 'sample/back.jpg', 300, 200, 200, 200,
                 True, 'bananas', 'peaches', False, 'yellow', 'green'],
            ],
            inputs=itm_inputs,
            outputs=itm_outputs,
            fn=itm,
            cache_examples=True,
        )

    btn.click(fn=itm, inputs=itm_inputs, outputs=itm_outputs)

demo.launch()