j-min committed on
Commit 342816e • 1 Parent(s): f419f95

Initial commit

Files changed (4)
  1. app.py +736 -0
  2. gen_utils.py +208 -0
  3. images/blank.png +0 -0
  4. requirements.txt +32 -0
app.py ADDED
@@ -0,0 +1,736 @@
+ import gradio as gr
+
+
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+
+ from collections import Counter
+ import math
+
+ from gradio import processing_utils
+ from typing import Optional
+
+ import warnings
+
+ from datetime import datetime
+
+ import torch
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from diffusers import StableDiffusionInpaintPipeline
+ from accelerate.utils import set_seed
+
+ class Instance:
+     def __init__(self, capacity=2):
+         self.model_type = 'base'
+         self.loaded_model_list = {}
+         self.counter = Counter()
+         self.global_counter = Counter()
+         self.capacity = capacity
+
+         self.loaded_model = None
+
+     def _log(self, model_type, batch_size, instruction, phrase_list):
+         self.counter[model_type] += 1
+         self.global_counter[model_type] += 1
+         current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         print('[{}] Current: {}, All: {}. Samples: {}, prompt: {}, phrases: {}'.format(
+             current_time, dict(self.counter), dict(self.global_counter), batch_size, instruction, phrase_list
+         ))
+
+     def get_model(self):
+         if self.loaded_model is None:
+             self.loaded_model = self.load_model()
+         return self.loaded_model
+
+     def load_model(self, model_id='j-min/IterInpaint-CLEVR'):
+         pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
+         pipe = pipe.to('cuda')
+         def dummy(images, **kwargs):
+             return images, False
+         pipe.safety_checker = dummy
+         print("Disabled safety checker")
+
+         print("Loaded model")
+         return pipe
+
+ instance = Instance()
+
+ # from ldm.viz_utils import plot_results, fig2img, show_images
+ from gen_utils import encode_from_custom_annotation, iterinpaint_sample_diffusers
+
+ class ImageMask(gr.components.Image):
+     """
+     Sets: source="canvas", tool="sketch"
+     """
+
+     is_template = True
+
+     def __init__(self, **kwargs):
+         super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
+
+     def preprocess(self, x):
+         if x is None:
+             return x
+         if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
+             decode_image = processing_utils.decode_base64_to_image(x)
+             width, height = decode_image.size
+             mask = np.zeros((height, width, 4), dtype=np.uint8)
+             mask[..., -1] = 255
+             mask = self.postprocess(mask)
+             x = {'image': x, 'mask': mask}
+         return super().preprocess(x)
+
+
+ class Blocks(gr.Blocks):
+
+     def __init__(
+         self,
+         theme: str = "default",
+         analytics_enabled: Optional[bool] = None,
+         mode: str = "blocks",
+         title: str = "Gradio",
+         css: Optional[str] = None,
+         **kwargs,
+     ):
+
+         self.extra_configs = {
+             'thumbnail': kwargs.pop('thumbnail', ''),
+             'url': kwargs.pop('url', 'https://gradio.app/'),
+             'creator': kwargs.pop('creator', '@teamGradio'),
+         }
+
+         super(Blocks, self).__init__(
+             theme, analytics_enabled, mode, title, css, **kwargs)
+         warnings.filterwarnings("ignore")
+
+     def get_config_file(self):
+         config = super(Blocks, self).get_config_file()
+
+         for k, v in self.extra_configs.items():
+             config[k] = v
+
+         return config
+
+ def draw_box(boxes=[], texts=[], img=None):
+     if len(boxes) == 0 and img is None:
+         return None
+
+     if img is None:
+         img = Image.new('RGB', (512, 512), (255, 255, 255))
+     colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
+     draw = ImageDraw.Draw(img)
+     font = ImageFont.truetype("DejaVuSansMono.ttf", size=20)
+     for bid, box in enumerate(boxes):
+         draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
+         anno_text = texts[bid]
+         draw.rectangle([box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]], outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
+         draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size*1.2)], anno_text, font=font, fill=(255,255,255))
+     return img
+
+ def get_concat(ims):
+     if len(ims) == 1:
+         n_col = 1
+     else:
+         n_col = 2
+     n_row = math.ceil(len(ims) / 2)
+     dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
+     for i, im in enumerate(ims):
+         row_id = i // n_col
+         col_id = i % n_col
+         dst.paste(im, (im.width * col_id, im.height * row_id))
+     return dst
+
+
+ def inference(language_instruction, grounding_texts, boxes, guidance_scale):
+
+     # custom_annotations = [
+     #     {'x': 19,
+     #      'y': 61,
+     #      'width': 158,
+     #      'height': 169,
+     #      'label': 'blue metal cube'},
+     #     {'x': 183,
+     #      'y': 94,
+     #      'width': 103,
+     #      'height': 109,
+     #      'label': 'brown rubber sphere'},
+     # ]
+
+     # # boxes - normalized -> unnormalized
+     # boxes = np.array(boxes) * 512
+
+     custom_annotations = []
+     for i in range(len(boxes)):
+         box = boxes[i]
+         custom_annotations.append({'x': box[0],
+                                    'y': box[1],
+                                    'width': box[2] - box[0],
+                                    'height': box[3] - box[1],
+                                    'label': grounding_texts[i]})
+     # # 1) convert xywh to xyxy
+     # # 2) normalize coordinates
+     scene = encode_from_custom_annotation(custom_annotations, size=512)
+
+     print(scene['box_captions'])
+     print(scene['boxes_normalized'])
+
+     pipe = instance.get_model()
+
+     out = iterinpaint_sample_diffusers(
+         pipe, scene, paste=True, verbose=True, size=512, guidance_scale=guidance_scale)
+
+     final_image = out['generated_images'][-1].copy()
+
+     # Create generation GIF
+     prompts = out['prompts']
+
+     fps = 4
+
+     def create_gif_source_images(images, prompts):
+         """Create source images for the GIF.
+         Each frame consists of an intermediate image with its prompt as the title.
+         Don't change the size of the original images.
+         """
+
+         step_images = []
+         font = ImageFont.truetype("DejaVuSansMono.ttf", size=20)
+         for i, img in enumerate(images):
+             draw = ImageDraw.Draw(img)
+             draw.text((0, 0), prompts[i], (255, 255, 255), font=font)
+             step_images.append(img)
+         return step_images
+
+     import imageio
+
+     step_images = create_gif_source_images(out['generated_images'], prompts)
+     print("Number of frames in GIF: {}".format(len(step_images)))
+     # create temp path
+     import tempfile
+     import os
+     gif_save_path = os.path.join(tempfile.gettempdir(), 'gen.gif')
+
+     # create gif
+     imageio.mimsave(gif_save_path, step_images, fps=fps)
+     print('GIF saved to {}'.format(gif_save_path))
+
+     out_images = [
+         final_image,
+         gif_save_path
+     ]
+
+     return out_images
+
+ def generate(task, language_instruction, grounding_texts, sketch_pad,
+              alpha_sample, guidance_scale, batch_size,
+              fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
+              state):
+     if 'boxes' not in state:
+         state['boxes'] = []
+
+     boxes = state['boxes']
+     grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+     # assert len(boxes) == len(grounding_texts)
+     if len(boxes) != len(grounding_texts):
+         if len(boxes) < len(grounding_texts):
+             raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+ Number of boxes drawn: {}, number of grounding tokens: {}.
+ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+         grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+     # # normalize boxes
+     # boxes = (np.asarray(boxes) / 512).tolist()
+
+     print('input boxes: ', boxes)
+     print('input grounding_texts: ', grounding_texts)
+     print('input language instruction: ', language_instruction)
+
+     if fix_seed:
+         set_seed(rand_seed)
+         print('seed set to: ', rand_seed)
+
+     gen_image, gen_animation = inference(
+         language_instruction, grounding_texts, boxes,
+         guidance_scale=guidance_scale,
+     )
+
+     # for idx, gen_image in enumerate(gen_images):
+
+     #     if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
+     #         hw = min(*state['original_image'].shape[:2])
+     #         gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
+     #         gen_image = Image.fromarray(gen_image)
+
+     #     gen_images[idx] = gen_image
+
+     # blank_samples = batch_size % 2 if batch_size > 1 else 0
+     # gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
+     #     + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
+     #     + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
+
+     # gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
+     #     + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
+
+     gen_images = [
+         gr.Image.update(value=gen_image, visible=True),
+         gr.Image.update(value=gen_animation, visible=True)
+     ]
+
+     return gen_images + [state]
+
+
+ def binarize(x):
+     return (x != 0).astype('uint8') * 255
+
+ def sized_center_crop(img, cropx, cropy):
+     y, x = img.shape[:2]
+     startx = x // 2 - (cropx // 2)
+     starty = y // 2 - (cropy // 2)
+     return img[starty:starty+cropy, startx:startx+cropx]
+
+ def sized_center_fill(img, fill, cropx, cropy):
+     y, x = img.shape[:2]
+     startx = x // 2 - (cropx // 2)
+     starty = y // 2 - (cropy // 2)
+     img[starty:starty+cropy, startx:startx+cropx] = fill
+     return img
+
+ def sized_center_mask(img, cropx, cropy):
+     y, x = img.shape[:2]
+     startx = x // 2 - (cropx // 2)
+     starty = y // 2 - (cropy // 2)
+     center_region = img[starty:starty+cropy, startx:startx+cropx].copy()
+     img = (img * 0.2).astype('uint8')
+     img[starty:starty+cropy, startx:startx+cropx] = center_region
+     return img
+
+ def center_crop(img, HW=None, tgt_size=(512, 512)):
+     if HW is None:
+         H, W = img.shape[:2]
+         HW = min(H, W)
+     img = sized_center_crop(img, HW, HW)
+     img = Image.fromarray(img)
+     img = img.resize(tgt_size)
+     return np.array(img)
+
+ def draw(task, input, grounding_texts, new_image_trigger, state):
+     if type(input) == dict:
+         image = input['image']
+         mask = input['mask']
+     else:
+         mask = input
+
+     if mask.ndim == 3:
+         mask = mask[..., 0]
+
+     image_scale = 1.0
+
+     # resize trigger
+     if task == "Grounded Inpainting":
+         mask_cond = mask.sum() == 0
+         # size_cond = mask.shape != (512, 512)
+         if mask_cond and 'original_image' not in state:
+             image = Image.fromarray(image)
+             width, height = image.size
+             scale = 600 / min(width, height)
+             image = image.resize((int(width * scale), int(height * scale)))
+             state['original_image'] = np.array(image).copy()
+             image_scale = float(height / width)
+             return [None, new_image_trigger + 1, image_scale, state]
+         else:
+             original_image = state['original_image']
+             H, W = original_image.shape[:2]
+             image_scale = float(H / W)
+
+     mask = binarize(mask)
+     if mask.shape != (512, 512):
+         # assert False, "should not receive any non-512x512 masks."
+         if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
+             mask = center_crop(mask, state['inpaint_hw'])
+             image = center_crop(state['original_image'], state['inpaint_hw'])
+         else:
+             mask = np.zeros((512, 512), dtype=np.uint8)
+     # mask = center_crop(mask)
+     mask = binarize(mask)
+
+     if type(mask) != np.ndarray:
+         mask = np.array(mask)
+
+     if mask.sum() == 0 and task != "Grounded Inpainting":
+         state = {}
+
+     if task != 'Grounded Inpainting':
+         image = None
+     else:
+         image = Image.fromarray(image)
+
+     if 'boxes' not in state:
+         state['boxes'] = []
+
+     if 'masks' not in state or len(state['masks']) == 0:
+         state['masks'] = []
+         last_mask = np.zeros_like(mask)
+     else:
+         last_mask = state['masks'][-1]
+
+     if type(mask) == np.ndarray and mask.size > 1:
+         diff_mask = mask - last_mask
+     else:
+         diff_mask = np.zeros([])
+
+     if diff_mask.sum() > 0:
+         x1x2 = np.where(diff_mask.max(0) != 0)[0]
+         y1y2 = np.where(diff_mask.max(1) != 0)[0]
+         y1, y2 = y1y2.min(), y1y2.max()
+         x1, x2 = x1x2.min(), x1x2.max()
+
+         if (x2 - x1 > 5) and (y2 - y1 > 5):
+             state['masks'].append(mask.copy())
+             state['boxes'].append((x1, y1, x2, y2))
+
+     grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+     grounding_texts = [x for x in grounding_texts if len(x) > 0]
+     if len(grounding_texts) < len(state['boxes']):
+         grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
+
+     box_image = draw_box(state['boxes'], grounding_texts, image)
+
+     if box_image is not None and state.get('inpaint_hw', None):
+         inpaint_hw = state['inpaint_hw']
+         box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
+         original_image = state['original_image'].copy()
+         box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
+
+     return [box_image, new_image_trigger, image_scale, state]
+
+ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
+     if task != 'Grounded Inpainting':
+         sketch_pad_trigger = sketch_pad_trigger + 1
+     blank_samples = batch_size % 2 if batch_size > 1 else 0
+     # out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
+     #     + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
+     #     + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
+
+     out_images = [gr.Image.update(value=None, visible=True) for i in range(1)] \
+         + [gr.Image.update(value=None, visible=True) for _ in range(1)]
+     state = {}
+     return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
+
+ css = """
+ #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
+ {
+     height: var(--height) !important;
+     max-height: var(--height) !important;
+     min-height: var(--height) !important;
+ }
+ #paper-info a {
+     color: #008AD7;
+     text-decoration: none;
+ }
+ #paper-info a:hover {
+     cursor: pointer;
+     text-decoration: none;
+ }
+ """
+
+ rescale_js = """
+ function(x) {
+     const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
+     let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
+     const image_width = root.querySelector('#img2img_image').clientWidth;
+     const target_height = parseInt(image_width * image_scale);
+     document.body.style.setProperty('--height', `${target_height}px`);
+     root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
+     root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
+     return x;
+ }
+ """
+
+ with Blocks(
+     # css=css,
+     analytics_enabled=False,
+     title="IterInpaint demo",
+ ) as main:
+     description = """
+     <p style="text-align: center; font-weight: bold;">
+         <span style="font-size: 28px">IterInpaint CLEVR Demo</span>
+         <br>
+         <span style="font-size: 18px" id="paper-info">
+             [<a href="https://layoutbench.github.io" target="_blank">Project Page</a>]
+             [<a href="https://arxiv.org/abs/2304.06671" target="_blank">Paper</a>]
+             [<a href="https://github.com/j-min/IterInpaint" target="_blank">GitHub</a>]
+         </span>
+     </p>
+     <span style="font-size: 14px">
+         <b>IterInpaint</b> is a new baseline for layout-guided image generation.
+         Unlike previous methods that generate all objects in a single step, IterInpaint decomposes the image generation process into multiple steps and uses an inpainting model to update regions step-by-step.
+     </span>
+     <br>
+     <br>
+     <span style="font-size: 18px" id="instruction">
+         Instructions:
+     </span>
+     <p>
+         (1) &#9000;&#65039; Enter the object names in <em>Region Captions</em>
+         <br>
+         (2) &#128433;&#65039; Draw their corresponding bounding boxes one by one using <em>Sketch Pad</em> -- the parsed boxes will be displayed automatically.
+         <br>
+         For faster inference without waiting in the queue, you may duplicate the space and upgrade to a GPU in settings. <a href="https://huggingface.co/spaces/j-min/iterinpaint-CLEVR?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>
+     </p>
+     """
+     gr.HTML(description)
+
+     with gr.Row():
+         with gr.Column(scale=4):
+             sketch_pad_trigger = gr.Number(value=0, visible=False)
+             sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
+             init_white_trigger = gr.Number(value=0, visible=False)
+             image_scale = gr.Number(
+                 value=0, elem_id="image_scale", visible=False)
+             new_image_trigger = gr.Number(value=0, visible=False)
+
+             # task = gr.Radio(
+             #     choices=["Grounded Generation", 'Grounded Inpainting'],
+             #     type="value",
+             #     value="Grounded Generation",
+             #     label="Task",
+             # )
+             task = gr.State("Grounded Generation")
+             # language_instruction = gr.Textbox(
+             #     label="Language instruction",
+             # )
+             language_instruction = gr.State("")
+
+             grounding_instruction = gr.Textbox(
+                 label="""
+                 Region Captions (separated by semicolons)
+                 e.g., "blue metal cube; red rubber cylinder"
+                 """,
+             )
+             with gr.Row():
+                 sketch_pad = ImageMask(
+                     label="Draw bounding boxes", elem_id="img2img_image")
+                 out_imagebox = gr.Image(type="pil", label="Parsed Layout")
+             with gr.Row():
+                 clear_btn = gr.Button(value='Clear')
+                 gen_btn = gr.Button(value='Generate')
+             with gr.Accordion("Advanced Options", open=False):
+                 with gr.Column():
+                     # alpha_sample = gr.Slider(
+                     #     minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
+                     alpha_sample = gr.State(0.3)
+                     guidance_scale = gr.Slider(
+                         minimum=0, maximum=50, step=0.5, value=4.0, label="Guidance Scale")
+                     # batch_size = gr.Slider(
+                     #     minimum=1, maximum=4, step=1, value=2, label="Number of Samples")
+                     # batch_size = gr.Slider(
+                     #     minimum=1, maximum=1, step=1, value=1, label="Number of Samples")
+                     batch_size = gr.State(1)
+                     # append_grounding = gr.Checkbox(
+                     #     value=True, label="Append grounding instructions to the caption")
+                     append_grounding = gr.State(False)
+                     # use_actual_mask = gr.Checkbox(
+                     #     value=False, label="Use actual mask for inpainting", visible=False)
+                     use_actual_mask = gr.State(False)
+                     with gr.Row():
+                         # fix_seed = gr.Checkbox(value=True, label="Fixed seed")
+                         fix_seed = gr.State(True)
+                         rand_seed = gr.Slider(
+                             minimum=0, maximum=1000, step=1, value=0, label="Seed")
+                     with gr.Row():
+                         # use_style_cond = gr.Checkbox(
+                         #     value=False, label="Enable Style Condition")
+                         # style_cond_image = gr.Image(
+                         #     type="pil", label="Style Condition", visible=False, interactive=True)
+                         use_style_cond = gr.State(False)
+                         style_cond_image = gr.State(None)
+         with gr.Column(scale=3):
+             gr.HTML(
+                 '<span style="font-size: 20px; font-weight: bold">Generated Image</span>')
+             # with gr.Row():
+             out_gen_1 = gr.Image(
+                 type="pil", visible=True, show_label=False)
+             gr.HTML(
+                 '<span style="font-size: 20px; font-weight: bold">Step-by-Step Animation</span>')
+             out_gen_2 = gr.Image(
+                 type="pil", visible=True, show_label=False)
+             # with gr.Row():
+             #     out_gen_3 = gr.Image(
+             #         type="pil", visible=False, show_label=False)
+             #     out_gen_4 = gr.Image(
+             #         type="pil", visible=False, show_label=False)
+
+     state = gr.State({})
+
+     class Controller:
+         def __init__(self):
+             self.calls = 0
+             self.tracks = 0
+             self.resizes = 0
+             self.scales = 0
+
+         def init_white(self, init_white_trigger):
+             self.calls += 1
+             return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger + 1
+
+         # def change_n_samples(self, n_samples):
+         #     blank_samples = n_samples % 2 if n_samples > 1 else 0
+         #     return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
+         #         + [gr.Image.update(visible=False)
+         #            for _ in range(4 - n_samples - blank_samples)]
+
+         def resize_centercrop(self, state):
+             self.resizes += 1
+             image = state['original_image'].copy()
+             inpaint_hw = int(0.9 * min(*image.shape[:2]))
+             state['inpaint_hw'] = inpaint_hw
+             image_cc = center_crop(image, inpaint_hw)
+             # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
+             return image_cc, state
+
+         def resize_masked(self, state):
+             self.resizes += 1
+             image = state['original_image'].copy()
+             inpaint_hw = int(0.9 * min(*image.shape[:2]))
+             state['inpaint_hw'] = inpaint_hw
+             image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
+             state['masked_image'] = image_mask.copy()
+             # print(f'mask triggered {self.resizes}')
+             return image_mask, state
+
+         def switch_task_hide_cond(self, task):
+             cond = False
+             if task == "Grounded Generation":
+                 cond = True
+
+             return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None, visible=False), gr.Slider.update(visible=cond), gr.Checkbox.update(visible=(not cond), value=False)
+
+     controller = Controller()
+     main.load(
+         lambda x: x + 1,
+         inputs=sketch_pad_trigger,
+         outputs=sketch_pad_trigger,
+         queue=False)
+     sketch_pad.edit(
+         draw,
+         inputs=[task, sketch_pad, grounding_instruction,
+                 sketch_pad_resize_trigger, state],
+         outputs=[out_imagebox, sketch_pad_resize_trigger,
+                  image_scale, state],
+         queue=False,
+     )
+     grounding_instruction.change(
+         draw,
+         inputs=[task, sketch_pad, grounding_instruction,
+                 sketch_pad_resize_trigger, state],
+         outputs=[out_imagebox, sketch_pad_resize_trigger,
+                  image_scale, state],
+         queue=False,
+     )
+     clear_btn.click(
+         clear,
+         inputs=[task, sketch_pad_trigger, batch_size, state],
+         outputs=[sketch_pad, sketch_pad_trigger, out_imagebox,
+                  # image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
+                  image_scale, out_gen_1, out_gen_2, state],
+         queue=False)
+     # task.change(
+     #     partial(clear, switch_task=True),
+     #     inputs=[task, sketch_pad_trigger, batch_size, state],
+     #     outputs=[sketch_pad, sketch_pad_trigger, out_imagebox,
+     #              image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
+     #     queue=False)
+     sketch_pad_trigger.change(
+         controller.init_white,
+         inputs=[init_white_trigger],
+         outputs=[sketch_pad, image_scale, init_white_trigger],
+         queue=False)
+     sketch_pad_resize_trigger.change(
+         controller.resize_masked,
+         inputs=[state],
+         outputs=[sketch_pad, state],
+         queue=False)
+     # batch_size.change(
+     #     controller.change_n_samples,
+     #     inputs=[batch_size],
+     #     outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
+     #     queue=False)
+     gen_btn.click(
+         generate,
+         inputs=[
+             task, language_instruction, grounding_instruction, sketch_pad,
+             alpha_sample, guidance_scale, batch_size,
+             fix_seed, rand_seed,
+             use_actual_mask,
+             append_grounding, style_cond_image,
+             state,
+         ],
+         # outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
+         outputs=[out_gen_1, out_gen_2, state],
+         queue=True
+     )
+     sketch_pad_resize_trigger.change(
+         None,
+         None,
+         sketch_pad_resize_trigger,
+         _js=rescale_js,
+         queue=False)
+     init_white_trigger.change(
+         None,
+         None,
+         init_white_trigger,
+         _js=rescale_js,
+         queue=False)
+     # use_style_cond.change(
+     #     lambda cond: gr.Image.update(visible=cond),
+     #     use_style_cond,
+     #     style_cond_image,
+     #     queue=False)
+     # task.change(
+     #     controller.switch_task_hide_cond,
+     #     inputs=task,
+     #     outputs=[use_style_cond, style_cond_image,
+     #              alpha_sample, use_actual_mask],
+     #     queue=False)
+
+     with gr.Column():
+         gr.Examples(
+             examples=[
+                 [
+                     "images/blank.png",
+                     "blue metal cube",
+                 ],
+                 [
+                     "images/blank.png",
+                     "green metal cube; red metal sphere; brown rubber cube",
+                 ],
+                 [
+                     "images/blank.png",
+                     "blue metal cube; brown rubber sphere; gray metal sphere; yellow rubber cylinder; gray metal cylinder; cyan rubber sphere; green rubber cube; red metal cylinder",
+                 ]
+             ],
+             inputs=[
+                 sketch_pad,
+                 grounding_instruction
+             ],
+             outputs=None,
+             fn=None,
+             cache_examples=False,
+         )
+
+     # https://huggingface.co/spaces/gligen/demo
+     # add hyperlink
+     thank_desc = """
+     Thanks to the
+     <a href="https://huggingface.co/spaces/gligen/demo" target="_blank">GLIGEN demo</a> for providing the bounding box parsing module.
+     """
+     gr.HTML(thank_desc)
+
+ main.queue(concurrency_count=1, api_open=False)
+ # main.launch(share=False, show_api=False, show_error=True)
+ main.launch(
+     server_name="0.0.0.0",
+     share=True,
+     # server_port=7864,
+     show_api=False, show_error=True
+ )
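
Usage note: the generation path above can also be exercised without the Gradio UI. The snippet below is a minimal sketch, assuming a CUDA device and gen_utils.py on the import path; the two example boxes and the output filename are illustrative, not taken from the repo. It wires encode_from_custom_annotation and iterinpaint_sample_diffusers together the same way inference() does:

from diffusers import StableDiffusionInpaintPipeline
from gen_utils import encode_from_custom_annotation, iterinpaint_sample_diffusers

# Load the IterInpaint checkpoint and disable the safety checker, as app.py does.
pipe = StableDiffusionInpaintPipeline.from_pretrained('j-min/IterInpaint-CLEVR').to('cuda')
pipe.safety_checker = lambda images, **kwargs: (images, False)

# Boxes in unnormalized pixel coordinates (x, y, width, height) on a 512x512 canvas.
# These values are made up for the example.
annotations = [
    {'x': 100, 'y': 100, 'width': 150, 'height': 150, 'label': 'blue metal cube'},
    {'x': 300, 'y': 250, 'width': 120, 'height': 120, 'label': 'red rubber sphere'},
]
scene = encode_from_custom_annotation(annotations, size=512)
out = iterinpaint_sample_diffusers(pipe, scene, paste=True, size=512, guidance_scale=4.0)
out['generated_images'][-1].save('result.png')  # final frame, after the background fill

Each element of out['generated_images'] is one inpainting step; these are the frames that inference() strings into the step-by-step GIF.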
gen_utils.py ADDED
@@ -0,0 +1,208 @@
+ import torch
+ from PIL import Image
+ from PIL import ImageDraw
+
+ def encode_scene(obj_list, H=320, W=320, src_bbox_format='xywh', tgt_bbox_format='xyxy'):
+     """Encode a scene into text and bounding boxes
+     Args:
+         obj_list: list of dicts
+             Each dict has keys:
+
+                 'color': str
+                 'material': str
+                 'shape': str
+             or
+                 'caption': str
+
+             and
+
+                 'bbox': list of 4 floats (unnormalized)
+                     [x0, y0, x1, y1] or [x0, y0, w, h]
+     """
+     box_captions = []
+     for obj in obj_list:
+         if 'caption' in obj:
+             box_caption = obj['caption']
+         else:
+             box_caption = f"{obj['color']} {obj['material']} {obj['shape']}"
+         box_captions += [box_caption]
+
+     assert src_bbox_format in ['xywh', 'xyxy'], f"src_bbox_format must be 'xywh' or 'xyxy', not {src_bbox_format}"
+     assert tgt_bbox_format in ['xywh', 'xyxy'], f"tgt_bbox_format must be 'xywh' or 'xyxy', not {tgt_bbox_format}"
+
+     boxes_unnormalized = []
+     boxes_normalized = []
+     for obj in obj_list:
+         if src_bbox_format == 'xywh':
+             x0, y0, w, h = obj['bbox']
+             x1 = x0 + w
+             y1 = y0 + h
+         elif src_bbox_format == 'xyxy':
+             x0, y0, x1, y1 = obj['bbox']
+             w = x1 - x0
+             h = y1 - y0
+         assert x1 > x0, f"x1={x1} <= x0={x0}"
+         assert y1 > y0, f"y1={y1} <= y0={y0}"
+         assert x1 <= W, f"x1={x1} > W={W}"
+         assert y1 <= H, f"y1={y1} > H={H}"
+
+         if tgt_bbox_format == 'xywh':
+             bbox_unnormalized = [x0, y0, w, h]
+             bbox_normalized = [x0 / W, y0 / H, w / W, h / H]
+
+         elif tgt_bbox_format == 'xyxy':
+             bbox_unnormalized = [x0, y0, x1, y1]
+             bbox_normalized = [x0 / W, y0 / H, x1 / W, y1 / H]
+
+         boxes_unnormalized += [bbox_unnormalized]
+         boxes_normalized += [bbox_normalized]
+
+     assert len(box_captions) == len(boxes_normalized), f"len(box_captions)={len(box_captions)} != len(boxes_normalized)={len(boxes_normalized)}"
+
+
+     out = {}
+     out['box_captions'] = box_captions
+     out['boxes_normalized'] = boxes_normalized
+     out['boxes_unnormalized'] = boxes_unnormalized
+
+     return out
+
+ def encode_from_custom_annotation(custom_annotations, size=512):
+     # custom_annotations = [
+     #     {'x': 83, 'y': 335, 'width': 70, 'height': 69, 'label': 'blue metal cube'},
+     #     {'x': 162, 'y': 302, 'width': 110, 'height': 138, 'label': 'blue metal cube'},
+     #     {'x': 274, 'y': 250, 'width': 191, 'height': 234, 'label': 'blue metal cube'},
+     #     {'x': 14, 'y': 18, 'width': 155, 'height': 205, 'label': 'blue metal cube'},
+     #     {'x': 175, 'y': 79, 'width': 106, 'height': 119, 'label': 'blue metal cube'},
+     #     {'x': 288, 'y': 111, 'width': 69, 'height': 63, 'label': 'blue metal cube'}
+     # ]
+     H, W = size, size
+
+     objects = []
+     for j in range(len(custom_annotations)):
+         xyxy = [
+             custom_annotations[j]['x'],
+             custom_annotations[j]['y'],
+             custom_annotations[j]['x'] + custom_annotations[j]['width'],
+             custom_annotations[j]['y'] + custom_annotations[j]['height']]
+         objects.append({
+             'caption': custom_annotations[j]['label'],
+             'bbox': xyxy,
+         })
+
+     out = encode_scene(objects, H=H, W=W,
+                        src_bbox_format='xyxy', tgt_bbox_format='xyxy')
+
+     return out
+
+
+
+ #### Below are helpers for HF diffusers
+
+ def iterinpaint_sample_diffusers(pipe, datum, paste=True, verbose=False, guidance_scale=4.0, size=512, background_instruction='Add gray background'):
+     d = datum
+
+     d['unnormalized_boxes'] = d['boxes_unnormalized']
+
+     n_total_boxes = len(d['unnormalized_boxes'])
+
+     context_imgs = []
+     mask_imgs = []
+     # masked_imgs = []
+     generated_images = []
+     prompts = []
+
+     context_img = Image.new('RGB', (size, size))
+     # context_draw = ImageDraw.Draw(context_img)
+     if verbose:
+         print('Initialized context image')
+
+     background_mask_img = Image.new('L', (size, size))
+     background_mask_draw = ImageDraw.Draw(background_mask_img)
+     background_mask_draw.rectangle([(0, 0), background_mask_img.size], fill=255)
+
+     for i in range(n_total_boxes):
+         if verbose:
+             print('Iter: ', i + 1, 'total: ', n_total_boxes)
+
+         target_caption = d['box_captions'][i]
+         if verbose:
+             print('Drawing ', target_caption)
+
+         mask_img = Image.new('L', context_img.size)
+         mask_draw = ImageDraw.Draw(mask_img)
+         mask_draw.rectangle([(0, 0), mask_img.size], fill=0)
+
+         box = d['unnormalized_boxes'][i]
+         if type(box) == list:
+             box = torch.tensor(box)
+         mask_draw.rectangle(box.long().tolist(), fill=255)
+         background_mask_draw.rectangle(box.long().tolist(), fill=0)
+
+         mask_imgs.append(mask_img.copy())
+
+
+         prompt = f"Add {d['box_captions'][i]}"
+
+         if verbose:
+             print('prompt:', prompt)
+         prompts += [prompt]
+
+         context_imgs.append(context_img.copy())
+
+         generated_image = pipe(
+             prompt,
+             context_img,
+             mask_img,
+             guidance_scale=guidance_scale).images[0]
+
+         if paste:
+             # context_img.paste(generated_image.crop(box.long().tolist()), box.long().tolist())
+
+
+             src_box = box.long().tolist()
+
+             # expand the paste box by 1 pixel on each side:
+             # (x0 - 1, y0 - 1, x1 + 1, y1 + 1)
+             paste_box = box.long().tolist()
+             paste_box[0] -= 1
+             paste_box[1] -= 1
+             paste_box[2] += 1
+             paste_box[3] += 1
+
+             box_w = paste_box[2] - paste_box[0]
+             box_h = paste_box[3] - paste_box[1]
+
+             context_img.paste(generated_image.crop(src_box).resize((box_w, box_h)), paste_box)
+             generated_images.append(context_img.copy())
+         else:
+             context_img = generated_image
+             generated_images.append(context_img.copy())
+
+     if verbose:
+         print('Fill background')
+
+     mask_img = background_mask_img
+
+     mask_imgs.append(mask_img)
+
+     prompt = background_instruction
+
+     if verbose:
+         print('prompt:', prompt)
+     prompts += [prompt]
+
+     generated_image = pipe(
+         prompt,
+         context_img,
+         mask_img,
+         guidance_scale=guidance_scale).images[0]
+
+     generated_images.append(generated_image)
+
+     return {
+         'context_imgs': context_imgs,
+         'mask_imgs': mask_imgs,
+         'prompts': prompts,
+         'generated_images': generated_images,
+     }
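
As a quick illustration of encode_from_custom_annotation, the numbers below simply follow the arithmetic in encode_scene (the single annotation is made up for the example):

from gen_utils import encode_from_custom_annotation

scene = encode_from_custom_annotation(
    [{'x': 64, 'y': 128, 'width': 128, 'height': 64, 'label': 'blue metal cube'}],
    size=512,
)
print(scene['box_captions'])        # ['blue metal cube']
print(scene['boxes_unnormalized'])  # [[64, 128, 192, 192]]  (xyxy, in pixels)
print(scene['boxes_normalized'])    # [[0.125, 0.25, 0.375, 0.375]]  (xyxy, divided by 512)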
images/blank.png ADDED
requirements.txt ADDED
@@ -0,0 +1,32 @@
+
+ transformers
+ diffusers
+ accelerate
+ opencv-python
+ timm
+ einops
+ datasets
+ rouge_score
+ omegaconf
+ ftfy
+ pycocotools
+ pycocoevalcap
+ albumentations
+ pudb
+ imageio
+ imageio-ffmpeg
+ clean-fid
+ h5py
+ pillow
+ setuptools
+ pytorch-lightning==1.5.9
+ tensorboardX==2.4.1
+ test-tube>=0.7.5
+ streamlit>=0.73.1
+ torch-fidelity==0.3.0
+ torchmetrics==0.6.0
+ kornia==0.6
+ chardet
+ cchardet
+ taming-transformers-rom1504
+ git+https://github.com/openai/CLIP.git