# IterInpaint-CLEVR / gen_utils.py
# Uploaded by j-min ("Initial commit", revision 342816e)
import torch
from PIL import Image
from PIL import ImageDraw
def encode_scene(obj_list, H=320, W=320, src_bbox_format='xywh', tgt_bbox_format='xyxy'):
    """Encode a scene into per-object captions and bounding boxes.

    Args:
        obj_list: list of dicts, each with either
            'caption': str
          or
            'color' / 'material' / 'shape': str (joined into a caption)
          plus
            'bbox': list of 4 floats in pixel units,
                [x0, y0, x1, y1] ('xyxy') or [x0, y0, w, h] ('xywh')
        H, W: canvas height/width used for validation and normalization.
        src_bbox_format, tgt_bbox_format: 'xywh' or 'xyxy'.

    Returns:
        dict with 'box_captions', 'boxes_normalized' (divided by W/H) and
        'boxes_unnormalized' (pixel units, in tgt_bbox_format).
    """
    # Per-object text: an explicit caption wins; otherwise build one from attributes.
    box_captions = [
        obj['caption'] if 'caption' in obj
        else f"{obj['color']} {obj['material']} {obj['shape']}"
        for obj in obj_list
    ]

    assert src_bbox_format in ['xywh', 'xyxy'], f"src_bbox_format must be 'xywh' or 'xyxy', not {src_bbox_format}"
    assert tgt_bbox_format in ['xywh', 'xyxy'], f"tgt_bbox_format must be 'xywh' or 'xyxy', not {tgt_bbox_format}"

    boxes_unnormalized = []
    boxes_normalized = []
    for obj in obj_list:
        if src_bbox_format == 'xywh':
            x0, y0, w, h = obj['bbox']
            x1, y1 = x0 + w, y0 + h
        else:  # 'xyxy' — format already validated above
            x0, y0, x1, y1 = obj['bbox']
            w, h = x1 - x0, y1 - y0

        # Boxes must be non-degenerate and lie inside the H x W canvas.
        assert x1 > x0, f"x1={x1} <= x0={x0}"
        assert y1 > y0, f"y1={y1} <= y0={y0}"
        assert x1 <= W, f"x1={x1} > W={W}"
        assert y1 <= H, f"y1={y1} > H={H}"

        if tgt_bbox_format == 'xywh':
            unnormalized = [x0, y0, w, h]
            normalized = [x0 / W, y0 / H, w / W, h / H]
        else:  # 'xyxy'
            unnormalized = [x0, y0, x1, y1]
            normalized = [x0 / W, y0 / H, x1 / W, y1 / H]

        boxes_unnormalized.append(unnormalized)
        boxes_normalized.append(normalized)

    assert len(box_captions) == len(boxes_normalized), f"len(box_captions)={len(box_captions)} != len(boxes_normalized)={len(boxes_normalized)}"

    return {
        'box_captions': box_captions,
        'boxes_normalized': boxes_normalized,
        'boxes_unnormalized': boxes_unnormalized,
    }
def encode_from_custom_annotation(custom_annotations, size=512):
    """Convert {'x', 'y', 'width', 'height', 'label'} annotations to encode_scene() output.

    Example input:
        [{'x': 83, 'y': 335, 'width': 70, 'height': 69, 'label': 'blue metal cube'},
         {'x': 14, 'y': 18, 'width': 155, 'height': 205, 'label': 'blue metal cube'}]

    The canvas is assumed square (H = W = size); boxes are converted from
    x/y/width/height to [x0, y0, x1, y1] before encoding.
    """
    H = W = size
    objects = [
        {
            'caption': ann['label'],
            'bbox': [
                ann['x'],
                ann['y'],
                ann['x'] + ann['width'],
                ann['y'] + ann['height'],
            ],
        }
        for ann in custom_annotations
    ]
    return encode_scene(objects, H=H, W=W,
                        src_bbox_format='xyxy', tgt_bbox_format='xyxy')
#### Below are for HF diffusers
def iterinpaint_sample_diffusers(pipe, datum, paste=True, verbose=False, guidance_scale=4.0, size=512, background_instruction='Add gray background'):
    """Iteratively inpaint a scene with an HF diffusers inpainting pipeline.

    Objects are generated one at a time: for each box a rectangular mask is
    drawn, the pipeline inpaints "Add <caption>" into that region, and
    (optionally) only the generated crop is pasted back into the running
    context image. A final pass fills everything outside the union of boxes
    using `background_instruction`.

    Args:
        pipe: diffusers inpainting pipeline, invoked as
            pipe(prompt, image, mask_image, guidance_scale=...).images[0].
        datum: dict with 'box_captions' (list of str) and
            'boxes_unnormalized' (list of [x0, y0, x1, y1] pixel boxes).
            NOTE: mutated in place — an 'unnormalized_boxes' alias key is added.
        paste: if True, paste only the generated box crop back into the
            context; otherwise adopt the whole generated image as context.
        verbose: print progress messages.
        guidance_scale: classifier-free guidance scale passed to the pipeline.
        size: square canvas side length in pixels.
        background_instruction: prompt for the final background pass.

    Returns:
        dict with 'context_imgs' (context before each object step),
        'mask_imgs' (per-step masks + background mask), 'prompts', and
        'generated_images' (context after each step + final image).
    """
    d = datum
    # Alias kept because the rest of this function (and possibly callers)
    # reads 'unnormalized_boxes'; this intentionally mutates the input dict.
    d['unnormalized_boxes'] = d['boxes_unnormalized']
    n_total_boxes = len(d['unnormalized_boxes'])

    context_imgs = []
    mask_imgs = []
    generated_images = []
    prompts = []

    # Start from a black RGB canvas; objects are painted into it one by one.
    context_img = Image.new('RGB', (size, size))
    if verbose:
        print('Initialized context image')

    # Background mask starts fully on (255); each object's box is punched
    # out (0) below, so the final pass only inpaints non-object pixels.
    background_mask_img = Image.new('L', (size, size))
    background_mask_draw = ImageDraw.Draw(background_mask_img)
    background_mask_draw.rectangle([(0, 0), background_mask_img.size], fill=255)

    for i in range(n_total_boxes):
        if verbose:
            print('Iter: ', i+1, 'total: ', n_total_boxes)
        target_caption = d['box_captions'][i]
        if verbose:
            print('Drawing ', target_caption)

        # Per-step mask: white (inpaint) inside the target box only.
        mask_img = Image.new('L', context_img.size)
        mask_draw = ImageDraw.Draw(mask_img)
        mask_draw.rectangle([(0, 0), mask_img.size], fill=0)

        box = d['unnormalized_boxes'][i]
        if isinstance(box, list):  # normalize to a tensor so .long().tolist() works below
            box = torch.tensor(box)
        mask_draw.rectangle(box.long().tolist(), fill=255)
        background_mask_draw.rectangle(box.long().tolist(), fill=0)
        mask_imgs.append(mask_img.copy())

        prompt = f"Add {d['box_captions'][i]}"
        if verbose:
            print('prompt:', prompt)
        prompts += [prompt]

        context_imgs.append(context_img.copy())

        generated_image = pipe(
            prompt,
            context_img,
            mask_img,
            guidance_scale=guidance_scale).images[0]

        if paste:
            src_box = box.long().tolist()
            # Expand the paste region by 1px on every side (and resize the
            # crop to match) to hide seams at the mask boundary.
            paste_box = box.long().tolist()
            paste_box[0] -= 1
            paste_box[1] -= 1
            paste_box[2] += 1
            paste_box[3] += 1
            box_w = paste_box[2] - paste_box[0]
            box_h = paste_box[3] - paste_box[1]
            context_img.paste(generated_image.crop(src_box).resize((box_w, box_h)), paste_box)
            generated_images.append(context_img.copy())
        else:
            # Adopt the entire pipeline output as the new context.
            context_img = generated_image
            generated_images.append(context_img.copy())

    if verbose:
        print('Fill background')
    mask_img = background_mask_img
    mask_imgs.append(mask_img)
    prompt = background_instruction
    if verbose:
        print('prompt:', prompt)
    prompts += [prompt]

    generated_image = pipe(
        prompt,
        context_img,
        mask_img,
        guidance_scale=guidance_scale).images[0]
    generated_images.append(generated_image)

    return {
        'context_imgs': context_imgs,
        'mask_imgs': mask_imgs,
        'prompts': prompts,
        'generated_images': generated_images,
    }