import torch
from PIL import Image, ImageDraw


def encode_scene(obj_list, H=320, W=320, src_bbox_format='xywh', tgt_bbox_format='xyxy'):
    """Encode a scene into box captions and bounding boxes.

    Args:
        obj_list: list of dicts. Each dict has either
                'color': str, 'material': str, 'shape': str
            or
                'caption': str,
            and
                'bbox': list of 4 floats (unnormalized),
                    [x0, y0, x1, y1] or [x0, y0, w, h]
    """
    box_captions = []
    for obj in obj_list:
        if 'caption' in obj:
            box_caption = obj['caption']
        else:
            box_caption = f"{obj['color']} {obj['material']} {obj['shape']}"
        box_captions += [box_caption]

    assert src_bbox_format in ['xywh', 'xyxy'], \
        f"src_bbox_format must be 'xywh' or 'xyxy', not {src_bbox_format}"
    assert tgt_bbox_format in ['xywh', 'xyxy'], \
        f"tgt_bbox_format must be 'xywh' or 'xyxy', not {tgt_bbox_format}"

    boxes_unnormalized = []
    boxes_normalized = []
    for obj in obj_list:
        if src_bbox_format == 'xywh':
            x0, y0, w, h = obj['bbox']
            x1 = x0 + w
            y1 = y0 + h
        elif src_bbox_format == 'xyxy':
            x0, y0, x1, y1 = obj['bbox']
            w = x1 - x0
            h = y1 - y0

        # Sanity-check that each box is non-degenerate and inside the canvas.
        assert x1 > x0, f"x1={x1} <= x0={x0}"
        assert y1 > y0, f"y1={y1} <= y0={y0}"
        assert x1 <= W, f"x1={x1} > W={W}"
        assert y1 <= H, f"y1={y1} > H={H}"

        if tgt_bbox_format == 'xywh':
            bbox_unnormalized = [x0, y0, w, h]
            bbox_normalized = [x0 / W, y0 / H, w / W, h / H]
        elif tgt_bbox_format == 'xyxy':
            bbox_unnormalized = [x0, y0, x1, y1]
            bbox_normalized = [x0 / W, y0 / H, x1 / W, y1 / H]

        boxes_unnormalized += [bbox_unnormalized]
        boxes_normalized += [bbox_normalized]

    assert len(box_captions) == len(boxes_normalized), \
        f"len(box_captions)={len(box_captions)} != len(boxes_normalized)={len(boxes_normalized)}"

    return {
        'box_captions': box_captions,
        'boxes_normalized': boxes_normalized,
        'boxes_unnormalized': boxes_unnormalized,
    }


def encode_from_custom_annotation(custom_annotations, size=512):
    # custom_annotations = [
    #     {'x': 83, 'y': 335, 'width': 70, 'height': 69, 'label': 'blue metal cube'},
    #     {'x': 162, 'y': 302, 'width': 110, 'height': 138, 'label': 'blue metal cube'},
    #     {'x': 274, 'y': 250, 'width': 191, 'height': 234, 'label': 'blue metal cube'},
    #     {'x': 14, 'y': 18, 'width': 155, 'height': 205, 'label': 'blue metal cube'},
    #     {'x': 175, 'y': 79, 'width': 106, 'height': 119, 'label': 'blue metal cube'},
    #     {'x': 288, 'y': 111, 'width': 69, 'height': 63, 'label': 'blue metal cube'}
    # ]
    H, W = size, size
    objects = []
    for ann in custom_annotations:
        # Convert xywh annotations to xyxy.
        xyxy = [ann['x'], ann['y'], ann['x'] + ann['width'], ann['y'] + ann['height']]
        objects.append({
            'caption': ann['label'],
            'bbox': xyxy,
        })
    out = encode_scene(objects, H=H, W=W, src_bbox_format='xyxy', tgt_bbox_format='xyxy')
    return out


#### Below are for HF diffusers

def iterinpaint_sample_diffusers(pipe, datum, paste=True, verbose=False, guidance_scale=4.0,
                                 size=512, background_instruction='Add gray background'):
    """Iteratively inpaint the scene box-by-box, then fill the background."""
    d = datum
    d['unnormalized_boxes'] = d['boxes_unnormalized']
    n_total_boxes = len(d['unnormalized_boxes'])

    context_imgs = []
    mask_imgs = []
    generated_images = []
    prompts = []

    # Start from a blank (black) canvas.
    context_img = Image.new('RGB', (size, size))
    if verbose:
        print('Initialized context image')

    # The background mask starts fully white (everything to be inpainted);
    # each painted object box is carved out of it below.
    background_mask_img = Image.new('L', (size, size))
    background_mask_draw = ImageDraw.Draw(background_mask_img)
    background_mask_draw.rectangle([(0, 0), background_mask_img.size], fill=255)

    for i in range(n_total_boxes):
        if verbose:
            print('Iter:', i + 1, 'total:', n_total_boxes)

        target_caption = d['box_captions'][i]
        if verbose:
            print('Drawing', target_caption)

        # Mask covering only the current box (a new 'L' image is already all-black).
        mask_img = Image.new('L', context_img.size)
        mask_draw = ImageDraw.Draw(mask_img)

        box = d['unnormalized_boxes'][i]
        if isinstance(box, list):
            box = torch.tensor(box)
        mask_draw.rectangle(box.long().tolist(), fill=255)
        background_mask_draw.rectangle(box.long().tolist(), fill=0)
        mask_imgs.append(mask_img.copy())

        prompt = f"Add {target_caption}"
        if verbose:
            print('prompt:', prompt)
        prompts += [prompt]

        context_imgs.append(context_img.copy())

        generated_image = pipe(
            prompt,
            image=context_img,
            mask_image=mask_img,
            guidance_scale=guidance_scale).images[0]

        if paste:
            # Paste only the generated box region back into the context image,
            # expanding the paste box by 1px on each side (clamped to the canvas)
            # to hide seams at the box boundary.
            src_box = box.long().tolist()
            paste_box = [
                max(src_box[0] - 1, 0),
                max(src_box[1] - 1, 0),
                min(src_box[2] + 1, size),
                min(src_box[3] + 1, size),
            ]
            box_w = paste_box[2] - paste_box[0]
            box_h = paste_box[3] - paste_box[1]

            context_img.paste(generated_image.crop(src_box).resize((box_w, box_h)), paste_box)
        else:
            # Keep the whole generated image as the new context.
            context_img = generated_image
        generated_images.append(context_img.copy())

    if verbose:
        print('Fill background')

    # Final pass: inpaint everything outside the object boxes.
    mask_img = background_mask_img
    mask_imgs.append(mask_img)

    prompt = background_instruction
    if verbose:
        print('prompt:', prompt)
    prompts += [prompt]

    generated_image = pipe(
        prompt,
        image=context_img,
        mask_image=mask_img,
        guidance_scale=guidance_scale).images[0]
    generated_images.append(generated_image)

    return {
        'context_imgs': context_imgs,
        'mask_imgs': mask_imgs,
        'prompts': prompts,
        'generated_images': generated_images,
    }
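

#### Example usage: a minimal sketch, assuming any HF diffusers Stable Diffusion
#### inpainting checkpoint. The hub id below is a generic public checkpoint, not
#### necessarily the weights used with this code; substitute your own. The
#### annotations reuse the placeholder values from encode_from_custom_annotation.
if __name__ == '__main__':
    from diffusers import StableDiffusionInpaintPipeline

    # NOTE: placeholder checkpoint; swap in the project's own inpainting weights.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        'stabilityai/stable-diffusion-2-inpainting',
        torch_dtype=torch.float16,
    ).to('cuda')

    custom_annotations = [
        {'x': 83, 'y': 335, 'width': 70, 'height': 69, 'label': 'blue metal cube'},
        {'x': 175, 'y': 79, 'width': 106, 'height': 119, 'label': 'blue metal cube'},
    ]
    datum = encode_from_custom_annotation(custom_annotations, size=512)

    out = iterinpaint_sample_diffusers(pipe, datum, paste=True, verbose=True)
    # The last generated image includes the background fill.
    out['generated_images'][-1].save('iterinpaint_sample.png')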