Haotian Liu committed on
Commit
f9a674e
1 Parent(s): 6c8fcd4

Upload app

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +112 -0
  2. DejaVuSansMono.ttf +0 -0
  3. __init__.py +0 -0
  4. app.py +758 -0
  5. dataset/__init__.py +0 -0
  6. dataset/base_dataset.py +220 -0
  7. dataset/catalog.py +72 -0
  8. dataset/cd_dataset.py +250 -0
  9. dataset/concat_dataset.py +65 -0
  10. dataset/grounding_dataset.py +205 -0
  11. dataset/layout_dataset.py +237 -0
  12. dataset/tsv.py +212 -0
  13. dataset/tsv_dataset.py +326 -0
  14. dataset/utils.py +116 -0
  15. environment.yaml +29 -0
  16. gligen/__init__.py +10 -0
  17. gligen/create_meta.py +170 -0
  18. gligen/distributed.py +122 -0
  19. gligen/evaluator.py +225 -0
  20. gligen/ldm/__init__.py +3 -0
  21. gligen/ldm/data/__init__.py +0 -0
  22. gligen/ldm/data/base.py +23 -0
  23. gligen/ldm/data/imagenet.py +394 -0
  24. gligen/ldm/data/imagenet_clsidx_to_label.txt +1000 -0
  25. gligen/ldm/data/index_synset.yaml +1000 -0
  26. gligen/ldm/data/lsun.py +92 -0
  27. gligen/ldm/lr_scheduler.py +98 -0
  28. gligen/ldm/models/autoencoder.py +52 -0
  29. gligen/ldm/models/diffusion/__init__.py +0 -0
  30. gligen/ldm/models/diffusion/classifier.py +267 -0
  31. gligen/ldm/models/diffusion/ddim.py +134 -0
  32. gligen/ldm/models/diffusion/ddpm.py +72 -0
  33. gligen/ldm/models/diffusion/ldm.py +88 -0
  34. gligen/ldm/models/diffusion/plms.py +162 -0
  35. gligen/ldm/modules/attention.py +387 -0
  36. gligen/ldm/modules/diffusionmodules/__init__.py +0 -0
  37. gligen/ldm/modules/diffusionmodules/model.py +835 -0
  38. gligen/ldm/modules/diffusionmodules/openaimodel.py +489 -0
  39. gligen/ldm/modules/diffusionmodules/positionnet.py +50 -0
  40. gligen/ldm/modules/diffusionmodules/positionnet_with_image.py +68 -0
  41. gligen/ldm/modules/diffusionmodules/util.py +277 -0
  42. gligen/ldm/modules/distributions/__init__.py +0 -0
  43. gligen/ldm/modules/distributions/distributions.py +92 -0
  44. gligen/ldm/modules/ema.py +76 -0
  45. gligen/ldm/modules/encoders/__init__.py +0 -0
  46. gligen/ldm/modules/encoders/modules.py +245 -0
  47. gligen/ldm/modules/encoders/modules_backup.py +234 -0
  48. gligen/ldm/modules/image_degradation/__init__.py +2 -0
  49. gligen/ldm/modules/image_degradation/bsrgan.py +730 -0
  50. gligen/ldm/modules/image_degradation/bsrgan_light.py +650 -0
.gitignore ADDED
@@ -0,0 +1,112 @@
+ # IntelliJ project files
+ .idea
+ *.iml
+ out
+ gen
+
+ ### Vim template
+ [._]*.s[a-w][a-z]
+ [._]s[a-w][a-z]
+ *.un~
+ Session.vim
+ .netrwhist
+ *~
+
+ ### IPythonNotebook template
+ # Temporary data
+ .ipynb_checkpoints/
+
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ #lib/
+ #lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *,cover
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ *.ipynb
+ *.params
+ # *.json
+ .vscode/
+ *.code-workspace/
+
+ lib/pycocotools/_mask.c
+ lib/nms/cpu_nms.c
+
+ OUTPUT
+ OUTPUT/*
+ models/*
+ DATASET
+ DATASET/*
+ external/
+ MODELS
+ MODELS/*
+ gradio_cached_examples/*
+
+ kill.sh
+
+ draws/
+ #:wq
+ #plot/figs
+
+ *venv/*
+
+ # images
+ # images/*
+
+ create_samples/
+ create_samples/*
+
+ ckpts/*
DejaVuSansMono.ttf ADDED
Binary file (341 kB). View file
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,758 @@
1
+ import gradio as gr
2
+ import torch
3
+ import argparse
4
+ from omegaconf import OmegaConf
5
+ from gligen.task_grounded_generation import grounded_generation_box, load_ckpt
6
+
7
+ import json
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw, ImageFont
10
+ from functools import partial
11
+ import math
12
+
13
+ from gradio import processing_utils
14
+ from typing import Optional
15
+
16
+ from huggingface_hub import hf_hub_download
17
+ hf_hub_download = partial(hf_hub_download, library_name="gligen_demo")
18
+
19
+
20
+ arg_bool = lambda x: x.lower() == 'true'
21
+
22
+
23
+ def parse_option():
24
+ parser = argparse.ArgumentParser('GLIGen Demo', add_help=False)
25
+ parser.add_argument("--folder", type=str, default="create_samples", help="path to OUTPUT")
26
+ parser.add_argument("--official_ckpt", type=str, default='ckpts/sd-v1-4.ckpt', help="")
27
+ parser.add_argument("--guidance_scale", type=float, default=5, help="")
28
+ parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
29
+ parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
30
+ parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=True, help="Load text-box inpainting pipeline.")
31
+ parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=True, help="Load text-image-box generation pipeline.")
32
+ args = parser.parse_args()
33
+ return args
34
+ args = parse_option()
35
+
36
+
37
+ def load_from_hf(repo_id, filename='diffusion_pytorch_model.bin'):
38
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
39
+ return torch.load(cache_file, map_location='cpu')
40
+
41
+ def load_ckpt_config_from_hf(modality):
42
+ ckpt = load_from_hf(f'gligen/{modality}')
43
+ config = load_from_hf('gligen/demo_config_legacy', filename=f'{modality}.pth')
44
+ return ckpt, config
45
+
46
+
47
+ if args.load_text_box_generation:
48
+ pretrained_ckpt_gligen, config = load_ckpt_config_from_hf('gligen-generation-text-box')
49
+ config = OmegaConf.create( config["_content"] ) # config used in training
50
+ config.update( vars(args) )
51
+ config.model['params']['is_inpaint'] = False
52
+ config.model['params']['is_style'] = False
53
+ loaded_model_list = load_ckpt(config, pretrained_ckpt_gligen)
54
+
55
+
56
+ if args.load_text_box_inpainting:
57
+ pretrained_ckpt_gligen_inpaint, config = load_ckpt_config_from_hf('gligen-inpainting-text-box')
58
+ config = OmegaConf.create( config["_content"] ) # config used in training
59
+ config.update( vars(args) )
60
+ config.model['params']['is_inpaint'] = True
61
+ config.model['params']['is_style'] = False
62
+ loaded_model_list_inpaint = load_ckpt(config, pretrained_ckpt_gligen_inpaint)
63
+
64
+
65
+ if args.load_text_image_box_generation:
66
+ pretrained_ckpt_gligen_style, config = load_ckpt_config_from_hf('gligen-generation-text-image-box')
67
+ config = OmegaConf.create( config["_content"] ) # config used in training
68
+ config.update( vars(args) )
69
+ config.model['params']['is_inpaint'] = False
70
+ config.model['params']['is_style'] = True
71
+ loaded_model_list_style = load_ckpt(config, pretrained_ckpt_gligen_style)
72
+
73
+
74
+ def load_clip_model():
75
+ from transformers import CLIPProcessor, CLIPModel
76
+ version = "openai/clip-vit-large-patch14"
77
+ model = CLIPModel.from_pretrained(version).cuda()
78
+ processor = CLIPProcessor.from_pretrained(version)
79
+
80
+ return {
81
+ 'version': version,
82
+ 'model': model,
83
+ 'processor': processor,
84
+ }
85
+
86
+ clip_model = load_clip_model()
87
+
88
+
89
+ class ImageMask(gr.components.Image):
90
+ """
91
+ Sets: source="canvas", tool="sketch"
92
+ """
93
+
94
+ is_template = True
95
+
96
+ def __init__(self, **kwargs):
97
+ super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
98
+
99
+ def preprocess(self, x):
100
+ if x is None:
101
+ return x
102
+ if self.tool == "sketch" and self.source in ["upload", "webcam"] and type(x) != dict:
103
+ decode_image = processing_utils.decode_base64_to_image(x)
104
+ width, height = decode_image.size
105
+ mask = np.zeros((height, width, 4), dtype=np.uint8)
106
+ mask[..., -1] = 255
107
+ mask = self.postprocess(mask)
108
+ x = {'image': x, 'mask': mask}
109
+ return super().preprocess(x)
110
+
111
+
112
+ class Blocks(gr.Blocks):
113
+
114
+ def __init__(
115
+ self,
116
+ theme: str = "default",
117
+ analytics_enabled: Optional[bool] = None,
118
+ mode: str = "blocks",
119
+ title: str = "Gradio",
120
+ css: Optional[str] = None,
121
+ **kwargs,
122
+ ):
123
+
124
+ self.extra_configs = {
125
+ 'thumbnail': kwargs.pop('thumbnail', ''),
126
+ 'url': kwargs.pop('url', 'https://gradio.app/'),
127
+ 'creator': kwargs.pop('creator', '@teamGradio'),
128
+ }
129
+
130
+ super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
131
+
132
+ def get_config_file(self):
133
+ config = super(Blocks, self).get_config_file()
134
+
135
+ for k, v in self.extra_configs.items():
136
+ config[k] = v
137
+
138
+ return config
139
+
140
+ '''
141
+ inference model
142
+ '''
143
+
144
+ @torch.no_grad()
145
+ def inference(task, language_instruction, grounding_instruction, inpainting_boxes_nodrop, image,
146
+ alpha_sample, guidance_scale, batch_size,
147
+ fix_seed, rand_seed, actual_mask, style_image,
148
+ *args, **kwargs):
149
+ grounding_instruction = json.loads(grounding_instruction)
150
+ phrase_list, location_list = [], []
151
+ for k, v in grounding_instruction.items():
152
+ phrase_list.append(k)
153
+ location_list.append(v)
154
+
155
+ placeholder_image = Image.open('images/teddy.jpg').convert("RGB")
156
+ image_list = [placeholder_image] * len(phrase_list) # placeholder input for visual prompt, which is disabled
157
+
158
+ batch_size = int(batch_size)
159
+ if not 1 <= batch_size <= 2:
160
+ batch_size = 2
161
+
162
+ if style_image == None:
163
+ has_text_mask = 1
164
+ has_image_mask = 0 # then we hack above 'image_list'
165
+ else:
166
+ valid_phrase_len = len(phrase_list)
167
+
168
+ phrase_list += ['placeholder']
169
+ has_text_mask = [1]*valid_phrase_len + [0]
170
+
171
+ image_list = [placeholder_image]*valid_phrase_len + [style_image]
172
+ has_image_mask = [0]*valid_phrase_len + [1]
173
+
174
+ location_list += [ [0.0, 0.0, 1, 0.01] ] # style image grounding location
175
+
176
+ if task == 'Grounded Inpainting':
177
+ alpha_sample = 1.0
178
+
179
+ instruction = dict(
180
+ prompt = language_instruction,
181
+ phrases = phrase_list,
182
+ images = image_list,
183
+ locations = location_list,
184
+ alpha_type = [alpha_sample, 0, 1.0 - alpha_sample],
185
+ has_text_mask = has_text_mask,
186
+ has_image_mask = has_image_mask,
187
+ save_folder_name = language_instruction,
188
+ guidance_scale = guidance_scale,
189
+ batch_size = batch_size,
190
+ fix_seed = bool(fix_seed),
191
+ rand_seed = int(rand_seed),
192
+ actual_mask = actual_mask,
193
+ inpainting_boxes_nodrop = inpainting_boxes_nodrop,
194
+ )
195
+
196
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
197
+ if task == 'Grounded Generation':
198
+ if style_image == None:
199
+ return grounded_generation_box(loaded_model_list, instruction, *args, **kwargs)
200
+ else:
201
+ return grounded_generation_box(loaded_model_list_style, instruction, *args, **kwargs)
202
+ elif task == 'Grounded Inpainting':
203
+ assert image is not None
204
+ instruction['input_image'] = image.convert("RGB")
205
+ return grounded_generation_box(loaded_model_list_inpaint, instruction, *args, **kwargs)
206
+
207
+
208
+ def draw_box(boxes=[], texts=[], img=None):
209
+ if len(boxes) == 0 and img is None:
210
+ return None
211
+
212
+ if img is None:
213
+ img = Image.new('RGB', (512, 512), (255, 255, 255))
214
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
215
+ draw = ImageDraw.Draw(img)
216
+ font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
217
+ for bid, box in enumerate(boxes):
218
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline=colors[bid % len(colors)], width=4)
219
+ anno_text = texts[bid]
220
+ draw.rectangle([box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]], outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4)
221
+ draw.text([box[0] + int(font.size * 0.2), box[3] - int(font.size*1.2)], anno_text, font=font, fill=(255,255,255))
222
+ return img
223
+
224
+ def get_concat(ims):
225
+ if len(ims) == 1:
226
+ n_col = 1
227
+ else:
228
+ n_col = 2
229
+ n_row = math.ceil(len(ims) / 2)
230
+ dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
231
+ for i, im in enumerate(ims):
232
+ row_id = i // n_col
233
+ col_id = i % n_col
234
+ dst.paste(im, (im.width * col_id, im.height * row_id))
235
+ return dst
236
+
237
+
238
+ def auto_append_grounding(language_instruction, grounding_texts):
239
+ for grounding_text in grounding_texts:
240
+ if grounding_text not in language_instruction and grounding_text != 'auto':
241
+ language_instruction += "; " + grounding_text
242
+ print(language_instruction)
243
+ return language_instruction
244
+
245
+
246
+
247
+
248
+ def generate(task, language_instruction, grounding_texts, sketch_pad,
249
+ alpha_sample, guidance_scale, batch_size,
250
+ fix_seed, rand_seed, use_actual_mask, append_grounding, style_cond_image,
251
+ state):
252
+ if 'boxes' not in state:
253
+ state['boxes'] = []
254
+
255
+ boxes = state['boxes']
256
+ grounding_texts = [x.strip() for x in grounding_texts.split(';')]
257
+ assert len(boxes) == len(grounding_texts)
258
+ boxes = (np.asarray(boxes) / 512).tolist()
259
+ grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes)})
260
+
261
+ image = None
262
+ actual_mask = None
263
+ if task == 'Grounded Inpainting':
264
+ image = state.get('original_image', sketch_pad['image']).copy()
265
+ image = center_crop(image)
266
+ image = Image.fromarray(image)
267
+
268
+ if use_actual_mask:
269
+ actual_mask = sketch_pad['mask'].copy()
270
+ if actual_mask.ndim == 3:
271
+ actual_mask = actual_mask[..., 0]
272
+ actual_mask = center_crop(actual_mask, tgt_size=(64, 64))
273
+ actual_mask = torch.from_numpy(actual_mask == 0).float()
274
+
275
+ if state.get('inpaint_hw', None):
276
+ boxes = np.asarray(boxes) * 0.9 + 0.05
277
+ boxes = boxes.tolist()
278
+ grounding_instruction = json.dumps({obj: box for obj,box in zip(grounding_texts, boxes) if obj != 'auto'})
279
+
280
+ if append_grounding:
281
+ language_instruction = auto_append_grounding(language_instruction, grounding_texts)
282
+
283
+ gen_images, gen_overlays = inference(
284
+ task, language_instruction, grounding_instruction, boxes, image,
285
+ alpha_sample, guidance_scale, batch_size,
286
+ fix_seed, rand_seed, actual_mask, style_cond_image, clip_model=clip_model,
287
+ )
288
+
289
+ for idx, gen_image in enumerate(gen_images):
290
+
291
+ if task == 'Grounded Inpainting' and state.get('inpaint_hw', None):
292
+ hw = min(*state['original_image'].shape[:2])
293
+ gen_image = sized_center_fill(state['original_image'].copy(), np.array(gen_image.resize((hw, hw))), hw, hw)
294
+ gen_image = Image.fromarray(gen_image)
295
+
296
+ gen_images[idx] = gen_image
297
+
298
+ blank_samples = batch_size % 2 if batch_size > 1 else 0
299
+ gen_images = [gr.Image.update(value=x, visible=True) for i,x in enumerate(gen_images)] \
300
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
301
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
302
+
303
+ return gen_images + [state]
304
+
305
+
306
+ def binarize(x):
307
+ return (x != 0).astype('uint8') * 255
308
+
309
+ def sized_center_crop(img, cropx, cropy):
310
+ y, x = img.shape[:2]
311
+ startx = x // 2 - (cropx // 2)
312
+ starty = y // 2 - (cropy // 2)
313
+ return img[starty:starty+cropy, startx:startx+cropx]
314
+
315
+ def sized_center_fill(img, fill, cropx, cropy):
316
+ y, x = img.shape[:2]
317
+ startx = x // 2 - (cropx // 2)
318
+ starty = y // 2 - (cropy // 2)
319
+ img[starty:starty+cropy, startx:startx+cropx] = fill
320
+ return img
321
+
322
+ def sized_center_mask(img, cropx, cropy):
323
+ y, x = img.shape[:2]
324
+ startx = x // 2 - (cropx // 2)
325
+ starty = y // 2 - (cropy // 2)
326
+ center_region = img[starty:starty+cropy, startx:startx+cropx].copy()
327
+ img = (img * 0.2).astype('uint8')
328
+ img[starty:starty+cropy, startx:startx+cropx] = center_region
329
+ return img
330
+
331
+ def center_crop(img, HW=None, tgt_size=(512, 512)):
332
+ if HW is None:
333
+ H, W = img.shape[:2]
334
+ HW = min(H, W)
335
+ img = sized_center_crop(img, HW, HW)
336
+ img = Image.fromarray(img)
337
+ img = img.resize(tgt_size)
338
+ return np.array(img)
339
+
340
+ def draw(task, input, grounding_texts, new_image_trigger, state):
341
+ if type(input) == dict:
342
+ image = input['image']
343
+ mask = input['mask']
344
+ else:
345
+ mask = input
346
+
347
+ if mask.ndim == 3:
348
+ mask = mask[..., 0]
349
+
350
+ image_scale = 1.0
351
+
352
+ # resize trigger
353
+ if task == "Grounded Inpainting":
354
+ mask_cond = mask.sum() == 0
355
+ # size_cond = mask.shape != (512, 512)
356
+ if mask_cond and 'original_image' not in state:
357
+ image = Image.fromarray(image)
358
+ width, height = image.size
359
+ scale = 600 / min(width, height)
360
+ image = image.resize((int(width * scale), int(height * scale)))
361
+ state['original_image'] = np.array(image).copy()
362
+ image_scale = float(height / width)
363
+ return [None, new_image_trigger + 1, image_scale, state]
364
+ else:
365
+ original_image = state['original_image']
366
+ H, W = original_image.shape[:2]
367
+ image_scale = float(H / W)
368
+
369
+ mask = binarize(mask)
370
+ if mask.shape != (512, 512):
371
+ # assert False, "should not receive any non- 512x512 masks."
372
+ if 'original_image' in state and state['original_image'].shape[:2] == mask.shape:
373
+ mask = center_crop(mask, state['inpaint_hw'])
374
+ image = center_crop(state['original_image'], state['inpaint_hw'])
375
+ else:
376
+ mask = np.zeros((512, 512), dtype=np.uint8)
377
+ # mask = center_crop(mask)
378
+ mask = binarize(mask)
379
+
380
+ if type(mask) != np.ndarray:
381
+ mask = np.array(mask)
382
+
383
+ if mask.sum() == 0 and task != "Grounded Inpainting":
384
+ state = {}
385
+
386
+ if task != 'Grounded Inpainting':
387
+ image = None
388
+ else:
389
+ image = Image.fromarray(image)
390
+
391
+ if 'boxes' not in state:
392
+ state['boxes'] = []
393
+
394
+ if 'masks' not in state or len(state['masks']) == 0:
395
+ state['masks'] = []
396
+ last_mask = np.zeros_like(mask)
397
+ else:
398
+ last_mask = state['masks'][-1]
399
+
400
+ if type(mask) == np.ndarray and mask.size > 1:
401
+ diff_mask = mask - last_mask
402
+ else:
403
+ diff_mask = np.zeros([])
404
+
405
+ if diff_mask.sum() > 0:
406
+ x1x2 = np.where(diff_mask.max(0) != 0)[0]
407
+ y1y2 = np.where(diff_mask.max(1) != 0)[0]
408
+ y1, y2 = y1y2.min(), y1y2.max()
409
+ x1, x2 = x1x2.min(), x1x2.max()
410
+
411
+ if (x2 - x1 > 5) and (y2 - y1 > 5):
412
+ state['masks'].append(mask.copy())
413
+ state['boxes'].append((x1, y1, x2, y2))
414
+
415
+ grounding_texts = [x.strip() for x in grounding_texts.split(';')]
416
+ grounding_texts = [x for x in grounding_texts if len(x) > 0]
417
+ if len(grounding_texts) < len(state['boxes']):
418
+ grounding_texts += [f'Obj. {bid+1}' for bid in range(len(grounding_texts), len(state['boxes']))]
419
+
420
+ box_image = draw_box(state['boxes'], grounding_texts, image)
421
+
422
+ if box_image is not None and state.get('inpaint_hw', None):
423
+ inpaint_hw = state['inpaint_hw']
424
+ box_image_resize = np.array(box_image.resize((inpaint_hw, inpaint_hw)))
425
+ original_image = state['original_image'].copy()
426
+ box_image = sized_center_fill(original_image, box_image_resize, inpaint_hw, inpaint_hw)
427
+
428
+ return [box_image, new_image_trigger, image_scale, state]
429
+
430
+ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
431
+ if task != 'Grounded Inpainting':
432
+ sketch_pad_trigger = sketch_pad_trigger + 1
433
+ blank_samples = batch_size % 2 if batch_size > 1 else 0
434
+ out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)] \
435
+ + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
436
+ + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
437
+ state = {}
438
+ return [None, sketch_pad_trigger, None, 1.0] + out_images + [state]
439
+
440
+ css = """
441
+ #generate-btn {
442
+ --tw-border-opacity: 1;
443
+ border-color: rgb(255 216 180 / var(--tw-border-opacity));
444
+ --tw-gradient-from: rgb(255 216 180 / .7);
445
+ --tw-gradient-to: rgb(255 216 180 / 0);
446
+ --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
447
+ --tw-gradient-to: rgb(255 176 102 / .8);
448
+ --tw-text-opacity: 1;
449
+ color: rgb(238 116 0 / var(--tw-text-opacity));
450
+ }
451
+ #img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img
452
+ {
453
+ height: var(--height) !important;
454
+ max-height: var(--height) !important;
455
+ min-height: var(--height) !important;
456
+ }
457
+ #mirrors a:hover {
458
+ cursor:pointer;
459
+ }
460
+ #paper-info a {
461
+ color:#008AD7;
462
+ }
463
+ #paper-info a:hover {
464
+ cursor: pointer;
465
+ }
466
+ """
467
+
468
+ rescale_js = """
469
+ function(x) {
470
+ const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
471
+ let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
472
+ const image_width = root.querySelector('#img2img_image').clientWidth;
473
+ const target_height = parseInt(image_width * image_scale);
474
+ document.body.style.setProperty('--height', `${target_height}px`);
475
+ root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
476
+ root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
477
+ return x;
478
+ }
479
+ """
480
+
481
+ mirror_js = """
482
+ function () {
483
+ const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
484
+ const mirrors_div = root.querySelector('#mirrors');
485
+ const current_url = window.location.href;
486
+ const mirrors = [
487
+ 'https://dev.hliu.cc/gligen_mirror1/',
488
+ 'https://dev.hliu.cc/gligen_mirror2/',
489
+ ];
490
+
491
+ let mirror_html = '';
492
+ mirror_html += '[<a href="https://gligen.github.io" target="_blank" style="">Project Page</a>]';
493
+ mirror_html += '[<a href="https://arxiv.org/abs/2301.07093" target="_blank" style="">Paper</a>]';
494
+ mirror_html += '[<a href="https://github.com/gligen/GLIGEN" target="_blank" style="">GitHub Repo</a>]';
495
+ mirror_html += '&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;';
496
+ mirror_html += 'Mirrors: ';
497
+
498
+ mirrors.forEach((e, index) => {
499
+ let cur_index = index + 1;
500
+ if (current_url.includes(e)) {
501
+ mirror_html += `[Mirror ${cur_index}] `;
502
+ } else {
503
+ mirror_html += `[<a onclick="window.location.href = '${e}'">Mirror ${cur_index}</a>] `;
504
+ }
505
+ });
506
+
507
+ mirror_html = `<div class="output-markdown gr-prose" style="max-width: 100%;"><h3 style="text-align: center" id="paper-info">${mirror_html}</h3></div>`;
508
+
509
+ mirrors_div.innerHTML = mirror_html;
510
+ }
511
+ """
512
+
513
+ with Blocks(
514
+ css=css,
515
+ analytics_enabled=False,
516
+ title="GLIGen demo",
517
+ ) as main:
518
+ gr.Markdown('<h1 style="text-align: center;">GLIGen: Open-Set Grounded Text-to-Image Generation</h1>')
519
+ gr.Markdown("""<h3 style="text-align: center" id="paper-info">
520
+ [<a href="https://gligen.github.io" target="_blank" style="">Project Page</a>]
521
+ [<a href="https://arxiv.org/abs/2301.07093" target="_blank" style="">Paper</a>]
522
+ [<a href="https://github.com/gligen/GLIGEN" target="_blank" style="">GitHub Repo</a>]
523
+ </h3>""")
524
+ # gr.HTML("", elem_id="mirrors")
525
+ gr.Markdown("To ground concepts of interest with desired spatial specification, please (1) &#9000;&#65039; enter the concept names in <em> Grounding Instruction</em>, and (2) &#128433;&#65039; draw their corresponding bounding boxes one by one using <em> Sketch Pad</em> -- the parsed boxes will be displayed automatically.")
526
+ with gr.Row():
527
+ with gr.Column(scale=4):
528
+ sketch_pad_trigger = gr.Number(value=0, visible=False)
529
+ sketch_pad_resize_trigger = gr.Number(value=0, visible=False)
530
+ init_white_trigger = gr.Number(value=0, visible=False)
531
+ image_scale = gr.Number(value=0, elem_id="image_scale", visible=False)
532
+ new_image_trigger = gr.Number(value=0, visible=False)
533
+
534
+ task = gr.Radio(
535
+ choices=["Grounded Generation", 'Grounded Inpainting'],
536
+ type="value",
537
+ value="Grounded Generation",
538
+ label="Task",
539
+ )
540
+ language_instruction = gr.Textbox(
541
+ label="Language instruction",
542
+ )
543
+ grounding_instruction = gr.Textbox(
544
+ label="Grounding instruction (Separated by semicolon)",
545
+ )
546
+ with gr.Row():
547
+ sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image")
548
+ out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
549
+ with gr.Row():
550
+ clear_btn = gr.Button(value='Clear')
551
+ gen_btn = gr.Button(value='Generate', elem_id="generate-btn")
552
+ with gr.Accordion("Advanced Options", open=False):
553
+ with gr.Column():
554
+ alpha_sample = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.3, label="Scheduled Sampling (τ)")
555
+ guidance_scale = gr.Slider(minimum=0, maximum=50, step=0.5, value=7.5, label="Guidance Scale")
556
+ batch_size = gr.Slider(minimum=1, maximum=2, step=1, value=2, label="Number of Samples")
557
+ append_grounding = gr.Checkbox(value=True, label="Append grounding instructions to the caption")
558
+ use_actual_mask = gr.Checkbox(value=False, label="Use actual mask for inpainting", visible=False)
559
+ with gr.Row():
560
+ fix_seed = gr.Checkbox(value=True, label="Fixed seed")
561
+ rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=0, label="Seed")
562
+ with gr.Row():
563
+ use_style_cond = gr.Checkbox(value=False, label="Enable Style Condition")
564
+ style_cond_image = gr.Image(type="pil", label="Style Condition", visible=False, interactive=True)
565
+ with gr.Column(scale=4):
566
+ gr.Markdown("### Generated Images")
567
+ with gr.Row():
568
+ out_gen_1 = gr.Image(type="pil", visible=True, show_label=False)
569
+ out_gen_2 = gr.Image(type="pil", visible=True, show_label=False)
570
+ with gr.Row():
571
+ out_gen_3 = gr.Image(type="pil", visible=False, show_label=False)
572
+ out_gen_4 = gr.Image(type="pil", visible=False, show_label=False)
573
+
574
+ state = gr.State({})
575
+
576
+ class Controller:
577
+ def __init__(self):
578
+ self.calls = 0
579
+ self.tracks = 0
580
+ self.resizes = 0
581
+ self.scales = 0
582
+
583
+ def init_white(self, init_white_trigger):
584
+ self.calls += 1
585
+ return np.ones((512, 512), dtype='uint8') * 255, 1.0, init_white_trigger+1
586
+
587
+ def change_n_samples(self, n_samples):
588
+ blank_samples = n_samples % 2 if n_samples > 1 else 0
589
+ return [gr.Image.update(visible=True) for _ in range(n_samples + blank_samples)] \
590
+ + [gr.Image.update(visible=False) for _ in range(4 - n_samples - blank_samples)]
591
+
592
+ def resize_centercrop(self, state):
593
+ self.resizes += 1
594
+ image = state['original_image'].copy()
595
+ inpaint_hw = int(0.9 * min(*image.shape[:2]))
596
+ state['inpaint_hw'] = inpaint_hw
597
+ image_cc = center_crop(image, inpaint_hw)
598
+ # print(f'resize triggered {self.resizes}', image.shape, '->', image_cc.shape)
599
+ return image_cc, state
600
+
601
+ def resize_masked(self, state):
602
+ self.resizes += 1
603
+ image = state['original_image'].copy()
604
+ inpaint_hw = int(0.9 * min(*image.shape[:2]))
605
+ state['inpaint_hw'] = inpaint_hw
606
+ image_mask = sized_center_mask(image, inpaint_hw, inpaint_hw)
607
+ state['masked_image'] = image_mask.copy()
608
+ # print(f'mask triggered {self.resizes}')
609
+ return image_mask, state
610
+
611
+ def switch_task_hide_cond(self, task):
612
+ cond = False
613
+ if task == "Grounded Generation":
614
+ cond = True
615
+
616
+ return gr.Checkbox.update(visible=cond, value=False), gr.Image.update(value=None, visible=False), gr.Slider.update(visible=cond), gr.Checkbox.update(visible=(not cond), value=False)
617
+
618
+ controller = Controller()
619
+ main.load(
620
+ lambda x:x+1,
621
+ inputs=sketch_pad_trigger,
622
+ outputs=sketch_pad_trigger,
623
+ queue=False)
624
+ sketch_pad.edit(
625
+ draw,
626
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
627
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
628
+ queue=False,
629
+ )
630
+ grounding_instruction.change(
631
+ draw,
632
+ inputs=[task, sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
633
+ outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
634
+ queue=False,
635
+ )
636
+ clear_btn.click(
637
+ clear,
638
+ inputs=[task, sketch_pad_trigger, batch_size, state],
639
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
640
+ queue=False)
641
+ task.change(
642
+ partial(clear, switch_task=True),
643
+ inputs=[task, sketch_pad_trigger, batch_size, state],
644
+ outputs=[sketch_pad, sketch_pad_trigger, out_imagebox, image_scale, out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
645
+ queue=False)
646
+ sketch_pad_trigger.change(
647
+ controller.init_white,
648
+ inputs=[init_white_trigger],
649
+ outputs=[sketch_pad, image_scale, init_white_trigger],
650
+ queue=False)
651
+ sketch_pad_resize_trigger.change(
652
+ controller.resize_masked,
653
+ inputs=[state],
654
+ outputs=[sketch_pad, state],
655
+ queue=False)
656
+ batch_size.change(
657
+ controller.change_n_samples,
658
+ inputs=[batch_size],
659
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4],
660
+ queue=False)
661
+ gen_btn.click(
662
+ generate,
663
+ inputs=[
664
+ task, language_instruction, grounding_instruction, sketch_pad,
665
+ alpha_sample, guidance_scale, batch_size,
666
+ fix_seed, rand_seed,
667
+ use_actual_mask,
668
+ append_grounding, style_cond_image,
669
+ state,
670
+ ],
671
+ outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
672
+ queue=True
673
+ )
674
+ sketch_pad_resize_trigger.change(
675
+ None,
676
+ None,
677
+ sketch_pad_resize_trigger,
678
+ _js=rescale_js,
679
+ queue=False)
680
+ init_white_trigger.change(
681
+ None,
682
+ None,
683
+ init_white_trigger,
684
+ _js=rescale_js,
685
+ queue=False)
686
+ use_style_cond.change(
687
+ lambda cond: gr.Image.update(visible=cond),
688
+ use_style_cond,
689
+ style_cond_image,
690
+ queue=False)
691
+ task.change(
692
+ controller.switch_task_hide_cond,
693
+ inputs=task,
694
+ outputs=[use_style_cond, style_cond_image, alpha_sample, use_actual_mask],
695
+ queue=False)
696
+
697
+ with gr.Column():
698
+ gr.Examples(
699
+ examples=[
700
+ [
701
+ "images/blank.png",
702
+ "Grounded Generation",
703
+ "a dog and an apple",
704
+ "a dog;an apple",
705
+ ],
706
+ [
707
+ "images/blank.png",
708
+ "Grounded Generation",
709
+ "John Lennon is using a pc",
710
+ "John Lennon;a pc",
711
+ ],
712
+ [
713
+ "images/blank.png",
714
+ "Grounded Generation",
715
+ "a painting of a fox sitting in a field at sunrise in the style of Claude Monet",
716
+ "fox;sunrise",
717
+ ],
718
+ [
719
+ "images/blank.png",
720
+ "Grounded Generation",
721
+ "a beautiful painting of hot dog by studio ghibli, octane render, brilliantly coloured",
722
+ "hot dog",
723
+ ],
724
+ [
725
+ "images/blank.png",
726
+ "Grounded Generation",
727
+ "a sport car, unreal engine, global illumination, ray tracing",
728
+ "a sport car",
729
+ ],
730
+ [
731
+ "images/flower_beach.jpg",
732
+ "Grounded Inpainting",
733
+ "a squirrel and the space needle",
734
+ "a squirrel;the space needle",
735
+ ],
736
+ [
737
+ "images/arg_corgis.jpeg",
738
+ "Grounded Inpainting",
739
+ "a dog and a birthday cake",
740
+ "a dog; a birthday cake",
741
+ ],
742
+ [
743
+ "images/teddy.jpg",
744
+ "Grounded Inpainting",
745
+ "a teddy bear wearing a santa claus red shirt; holding a Christmas gift box on hand",
746
+ "a santa claus shirt; a Christmas gift box",
747
+ ],
748
+ ],
749
+ inputs=[sketch_pad, task, language_instruction, grounding_instruction],
750
+ outputs=None,
751
+ fn=None,
752
+ cache_examples=False,
753
+ )
754
+
755
+ main.queue(concurrency_count=1, api_open=False)
756
+ main.launch(share=False, show_api=False)
757
+
758
+
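For reference, the snippet below is a minimal, self-contained sketch (not part of this commit) of the box-grounding convention used by draw() and generate() in app.py above: the newly painted region of the sketch-pad mask is turned into a pixel bounding box with np.where, the box is normalized by the 512-px canvas, and the phrase-to-box mapping is serialized to JSON before being handed to inference(). Only numpy and the standard library are assumed.

import json
import numpy as np

def mask_to_box(diff_mask):
    # Column/row indices where newly painted pixels appear (cf. draw() in app.py).
    xs = np.where(diff_mask.max(0) != 0)[0]
    ys = np.where(diff_mask.max(1) != 0)[0]
    return (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))  # x1, y1, x2, y2

canvas = np.zeros((512, 512), dtype=np.uint8)
canvas[100:200, 50:150] = 255            # pretend the user painted one region
box = mask_to_box(canvas)                # (50, 100, 149, 199) in pixels

# generate() divides pixel boxes by the 512-px canvas and pairs them with the
# semicolon-separated grounding texts before calling inference().
grounding_texts = ["a dog"]
boxes = (np.asarray([box]) / 512).tolist()
grounding_instruction = json.dumps(dict(zip(grounding_texts, boxes)))
print(grounding_instruction)             # {"a dog": [0.09765625, 0.1953125, ...]}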
dataset/__init__.py ADDED
File without changes
dataset/base_dataset.py ADDED
@@ -0,0 +1,220 @@
1
+ import torch
2
+ from PIL import Image, ImageDraw
3
+ import torchvision.transforms as transforms
4
+ import torchvision
5
+ from zipfile import ZipFile
6
+ import os
7
+ import multiprocessing
8
+ import math
9
+ import numpy as np
10
+ import random
11
+ from io import BytesIO
12
+
13
+ VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
14
+
15
+
16
+ def check_filenames_in_zipdata(filenames, ziproot):
17
+ samples = []
18
+ for fst in ZipFile(ziproot).infolist():
19
+ fname = fst.filename
20
+ if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
21
+ continue
22
+ if os.path.splitext(fname)[1].lower() in VALID_IMAGE_TYPES:
23
+ samples.append((fname))
24
+ filenames = set(filenames)
25
+ samples = set(samples)
26
+ assert filenames.issubset(samples), 'Something wrong with your zip data'
27
+
28
+
29
+
30
+ def draw_box(img, boxes):
31
+ colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
32
+ draw = ImageDraw.Draw(img)
33
+ for bid, box in enumerate(boxes):
34
+ draw.rectangle([box[0], box[1], box[2], box[3]], outline =colors[bid % len(colors)], width=4)
35
+ # draw.rectangle([box[0], box[1], box[2], box[3]], outline ="red", width=2) # x0 y0 x1 y1
36
+ return img
37
+
38
+
39
+
40
+ def to_valid(x0, y0, x1, y1, image_size, min_box_size):
41
+ valid = True
42
+
43
+ if x0>image_size or y0>image_size or x1<0 or y1<0:
44
+ valid = False # no way to make this box valid, it is completely cropped out
45
+ return valid, (None, None, None, None)
46
+
47
+ x0 = max(x0, 0)
48
+ y0 = max(y0, 0)
49
+ x1 = min(x1, image_size)
50
+ y1 = min(y1, image_size)
51
+
52
+ if (x1-x0)*(y1-y0) / (image_size*image_size) < min_box_size:
53
+ valid = False
54
+ return valid, (None, None, None, None)
55
+
56
+ return valid, (x0, y0, x1, y1)
57
+
58
+
59
+
60
+
61
+
62
+ def recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, image_size, min_box_size):
63
+ """
64
+ x,y,w,h: the original annotation corresponding to the raw image size.
65
+ trans_info: what resizing and cropping have been applied to the raw image
66
+ image_size: what is the final image size
67
+ """
68
+
69
+ x0 = x * trans_info["performed_scale"] - trans_info['crop_x']
70
+ y0 = y * trans_info["performed_scale"] - trans_info['crop_y']
71
+ x1 = (x + w) * trans_info["performed_scale"] - trans_info['crop_x']
72
+ y1 = (y + h) * trans_info["performed_scale"] - trans_info['crop_y']
73
+
74
+
75
+ # at this point, the box annotation has been recalculated based on scaling and cropping,
76
+ # but some points may fall outside the image_size region (e.g., become negative), so we
77
+ # need to clamp them into [0, image_size]. If all points fall outside the image
78
+ # region, we consider the box invalid.
79
+ valid, (x0, y0, x1, y1) = to_valid(x0, y0, x1, y1, image_size, min_box_size)
80
+
81
+ if valid:
82
+ # we also perform random flip.
83
+ # Here boxes are valid, and are based on image_size
84
+ if trans_info["performed_flip"]:
85
+ x0, x1 = image_size-x1, image_size-x0
86
+
87
+ return valid, (x0, y0, x1, y1)
88
+
89
+
90
+
91
+ class BaseDataset(torch.utils.data.Dataset):
92
+ def __init__(self, image_root, random_crop, random_flip, image_size):
93
+ super().__init__()
94
+ self.image_root = image_root
95
+ self.random_crop = random_crop
96
+ self.random_flip = random_flip
97
+ self.image_size = image_size
98
+ self.use_zip = False
99
+
100
+ if image_root[-4::] == 'zip':
101
+ self.use_zip = True
102
+ self.zip_dict = {}
103
+
104
+ if self.random_crop:
105
+ assert False, 'NOT IMPLEMENTED'
106
+
107
+
108
+ def fetch_zipfile(self, ziproot):
109
+ pid = multiprocessing.current_process().pid # get pid of this process.
110
+ if pid not in self.zip_dict:
111
+ self.zip_dict[pid] = ZipFile(ziproot)
112
+ zip_file = self.zip_dict[pid]
113
+ return zip_file
114
+
115
+ def fetch_image(self, filename):
116
+ if self.use_zip:
117
+ zip_file = self.fetch_zipfile(self.image_root)
118
+ image = Image.open( BytesIO(zip_file.read(filename)) ).convert('RGB')
119
+ return image
120
+ else:
121
+ image = Image.open( os.path.join(self.image_root,filename) ).convert('RGB')
122
+ return image
123
+
124
+
125
+ def vis_getitem_data(self, index=None, out=None, return_tensor=False, name="res.jpg", print_caption=True):
126
+
127
+ if out is None:
128
+ out = self[index]
129
+
130
+ img = torchvision.transforms.functional.to_pil_image( out["image"]*0.5+0.5 )
131
+ canvas = torchvision.transforms.functional.to_pil_image( torch.ones_like(out["image"]) )
132
+ W, H = img.size
133
+
134
+ if print_caption:
135
+ caption = out["caption"]
136
+ print(caption)
137
+ print(" ")
138
+
139
+ boxes = []
140
+ for box in out["boxes"]:
141
+ x0,y0,x1,y1 = box
142
+ boxes.append( [float(x0*W), float(y0*H), float(x1*W), float(y1*H)] )
143
+ img = draw_box(img, boxes)
144
+
145
+ if return_tensor:
146
+ return torchvision.transforms.functional.to_tensor(img)
147
+ else:
148
+ img.save(name)
149
+
150
+
151
+ def transform_image(self, pil_image):
152
+ if self.random_crop:
153
+ assert False
154
+ arr = random_crop_arr(pil_image, self.image_size)
155
+ else:
156
+ arr, info = center_crop_arr(pil_image, self.image_size)
157
+
158
+ info["performed_flip"] = False
159
+ if self.random_flip and random.random()<0.5:
160
+ arr = arr[:, ::-1]
161
+ info["performed_flip"] = True
162
+
163
+ arr = arr.astype(np.float32) / 127.5 - 1
164
+ arr = np.transpose(arr, [2,0,1])
165
+
166
+ return torch.tensor(arr), info
167
+
168
+
169
+
170
+ def center_crop_arr(pil_image, image_size):
171
+ # We are not on a new enough PIL to support the `reducing_gap`
172
+ # argument, which uses BOX downsampling at powers of two first.
173
+ # Thus, we do it by hand to improve downsample quality.
174
+ WW, HH = pil_image.size
175
+
176
+ while min(*pil_image.size) >= 2 * image_size:
177
+ pil_image = pil_image.resize(
178
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
179
+ )
180
+
181
+ scale = image_size / min(*pil_image.size)
182
+
183
+ pil_image = pil_image.resize(
184
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
185
+ )
186
+
187
+ # at this point, the min of pil_image side is desired image_size
188
+ performed_scale = image_size / min(WW, HH)
189
+
190
+ arr = np.array(pil_image)
191
+ crop_y = (arr.shape[0] - image_size) // 2
192
+ crop_x = (arr.shape[1] - image_size) // 2
193
+
194
+ info = {"performed_scale":performed_scale, 'crop_y':crop_y, 'crop_x':crop_x, "WW":WW, 'HH':HH}
195
+
196
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size], info
197
+
198
+
199
+ def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
200
+ min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
201
+ max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
202
+ smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
203
+
204
+ # We are not on a new enough PIL to support the `reducing_gap`
205
+ # argument, which uses BOX downsampling at powers of two first.
206
+ # Thus, we do it by hand to improve downsample quality.
207
+ while min(*pil_image.size) >= 2 * smaller_dim_size:
208
+ pil_image = pil_image.resize(
209
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
210
+ )
211
+
212
+ scale = smaller_dim_size / min(*pil_image.size)
213
+ pil_image = pil_image.resize(
214
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
215
+ )
216
+
217
+ arr = np.array(pil_image)
218
+ crop_y = random.randrange(arr.shape[0] - image_size + 1)
219
+ crop_x = random.randrange(arr.shape[1] - image_size + 1)
220
+ return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
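As a worked example of the remapping above (a sketch, not part of the commit): center_crop_arr() records the scale and crop offsets it applied, and recalculate_box_and_verify_if_valid() replays them on a COCO-style (x, y, w, h) annotation. The call below assumes the repository root is on PYTHONPATH so dataset.base_dataset is importable and its dependencies (torch, torchvision, PIL) are installed.

from dataset.base_dataset import recalculate_box_and_verify_if_valid

# A 640x480 image resized by 256/480 and center-cropped to 256x256:
# the resized width is round(640 * 256/480) = 341, so crop_x = (341 - 256) // 2 = 42.
trans_info = {"performed_scale": 256 / 480, "crop_x": 42, "crop_y": 0,
              "performed_flip": False, "WW": 640, "HH": 480}

valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
    x=100, y=50, w=200, h=150,
    trans_info=trans_info, image_size=256, min_box_size=0.01)
print(valid, (x0, y0, x1, y1))   # True, roughly (11.3, 26.7, 118.0, 106.7)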
dataset/catalog.py ADDED
@@ -0,0 +1,72 @@
+ import os
+
+ class DatasetCatalog:
+     def __init__(self, ROOT, which_embedder):
+         assert which_embedder in ['clip', 'bert']
+
+         # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+         self.VGGrounding = {
+             "target": "dataset.tsv_dataset.TSVDataset",
+             "train_params": dict(
+                 tsv_path=os.path.join(ROOT,'GROUNDING/gqa/tsv/train-00.tsv'),
+             )
+         }
+
+
+         # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+         self.FlickrGrounding = {
+             "target": "dataset.tsv_dataset.TSVDataset",
+             "train_params":dict(
+                 tsv_path=os.path.join(ROOT,'GROUNDING/flickr30k/tsv/train-00.tsv'),
+             )
+         }
+
+         # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+         self.SBUGrounding = {
+             "target": "dataset.tsv_dataset.TSVDataset",
+             "train_params":dict(
+                 tsv_path=os.path.join(ROOT,'GROUNDING/SBU/tsv/train-00.tsv'),
+             )
+         }
+
+
+         # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+         self.CC3MGrounding = {
+             "target": "dataset.tsv_dataset.TSVDataset",
+             "train_params":dict(
+                 tsv_path=os.path.join(ROOT,'GROUNDING/CC3M/tsv/train-00.tsv'),
+             )
+         }
+
+
+         # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+
+         self.CC12MGrounding = {
+             "target": "dataset.tsv_dataset.TSVDataset",
+             "train_params":dict(
+                 tsv_path=os.path.join(ROOT,'GROUNDING/CC12M/tsv/train-00.tsv'),
+             )
+         }
+
+
+         # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
+
+         # temp = 'category_embedding_clip.pth' if which_embedder == 'clip' else 'category_embedding_bert.pth'
+         # obj365_category_embedding_path = os.path.join(ROOT, 'OBJECTS365', temp)
+
+         self.Obj365Detection = {
+             "target": "dataset.tsv_dataset.TSVDataset",
+             "train_params":dict(
+                 tsv_path=os.path.join(ROOT,'OBJECTS365/tsv/train-00.tsv'),
+             ),
+         }
+
+
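Each catalog entry above is just a dict pairing a dotted "target" class path with its "train_params"; the training code that consumes it is not part of this commit. A hypothetical helper along these lines (names are illustrative, not the repo's API) shows how such an entry could be resolved with importlib:

import importlib

def instantiate_from_catalog(entry, **extra_kwargs):
    # e.g. "dataset.tsv_dataset.TSVDataset" -> module "dataset.tsv_dataset", class "TSVDataset"
    module_name, class_name = entry["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**entry["train_params"], **extra_kwargs)

# Usage sketch (requires the TSV data laid out under DATA_ROOT as listed above):
# catalog = DatasetCatalog(ROOT="DATA_ROOT", which_embedder="clip")
# flickr = instantiate_from_catalog(catalog.FlickrGrounding)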
dataset/cd_dataset.py ADDED
@@ -0,0 +1,250 @@
1
+ import json, os, random, math
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import torchvision.transforms as transforms
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from .base_dataset import BaseDataset, check_filenames_in_zipdata, recalculate_box_and_verify_if_valid
12
+ from io import BytesIO
13
+
14
+
15
+
16
+ def not_in_at_all(list1, list2):
17
+ for a in list1:
18
+ if a in list2:
19
+ return False
20
+ return True
21
+
22
+
23
+ def clean_annotations(annotations):
24
+ for anno in annotations:
25
+ anno.pop("segmentation", None)
26
+ anno.pop("area", None)
27
+ anno.pop("iscrowd", None)
28
+ # anno.pop("id", None)
29
+
30
+
31
+ def make_a_sentence(obj_names, clean=False):
32
+
33
+ if clean:
34
+ obj_names = [ name[:-6] if ("-other" in name) else name for name in obj_names]
35
+
36
+ caption = ""
37
+ tokens_positive = []
38
+ for obj_name in obj_names:
39
+ start_len = len(caption)
40
+ caption += obj_name
41
+ end_len = len(caption)
42
+ caption += ", "
43
+ tokens_positive.append(
44
+ [[start_len, end_len]] # in real caption, positive tokens can be disjoint, thus using list of list
45
+ )
46
+ caption = caption[:-2] # remove last ", "
47
+
48
+ return caption #, tokens_positive
49
+
50
+
51
+ def check_all_have_same_images(instances_data, stuff_data, caption_data):
52
+ if stuff_data is not None:
53
+ assert instances_data["images"] == stuff_data["images"]
54
+ if caption_data is not None:
55
+ assert instances_data["images"] == caption_data["images"]
56
+
57
+
58
+ class CDDataset(BaseDataset):
59
+ "CD: Caption Detection"
60
+ def __init__(self,
61
+ image_root,
62
+ category_embedding_path,
63
+ instances_json_path = None,
64
+ stuff_json_path = None,
65
+ caption_json_path = None,
66
+ prob_real_caption = 0,
67
+ fake_caption_type = 'empty',
68
+ image_size=256,
69
+ max_images=None,
70
+ min_box_size=0.01,
71
+ max_boxes_per_image=8,
72
+ include_other=False,
73
+ random_crop = False,
74
+ random_flip = True,
75
+ ):
76
+ super().__init__(random_crop, random_flip, image_size)
77
+
78
+ self.image_root = image_root
79
+ self.category_embedding_path = category_embedding_path
80
+ self.instances_json_path = instances_json_path
81
+ self.stuff_json_path = stuff_json_path
82
+ self.caption_json_path = caption_json_path
83
+ self.prob_real_caption = prob_real_caption
84
+ self.fake_caption_type = fake_caption_type
85
+ self.max_images = max_images
86
+ self.min_box_size = min_box_size
87
+ self.max_boxes_per_image = max_boxes_per_image
88
+ self.include_other = include_other
89
+
90
+
91
+ assert fake_caption_type in ["empty", "made"]
92
+ if prob_real_caption > 0:
93
+ assert caption_json_path is not None, "caption json must be given"
94
+
95
+
96
+ # Load all jsons
97
+ with open(instances_json_path, 'r') as f:
98
+ instances_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
99
+ clean_annotations(instances_data["annotations"])
100
+ self.instances_data = instances_data
101
+
102
+ self.stuff_data = None
103
+ if stuff_json_path is not None:
104
+ with open(stuff_json_path, 'r') as f:
105
+ stuff_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
106
+ clean_annotations(stuff_data["annotations"])
107
+ self.stuff_data = stuff_data
108
+
109
+ self.captions_data = None
110
+ if caption_json_path is not None:
111
+ with open(caption_json_path, 'r') as f:
112
+ captions_data = json.load(f) # keys: 'info', 'images', 'licenses', 'categories', 'annotations'
113
+ clean_annotations(captions_data["annotations"])
114
+ self.captions_data = captions_data
115
+
116
+
117
+ # Load preprocessed name embedding
118
+ self.category_embeddings = torch.load(category_embedding_path)
119
+ self.embedding_len = list( self.category_embeddings.values() )[0].shape[0]
120
+
121
+
122
+ # Misc
123
+ self.image_ids = [] # main list for selecting images
124
+ self.image_id_to_filename = {} # file names used to read image
125
+ check_all_have_same_images(self.instances_data, self.stuff_data, self.captions_data)
126
+ for image_data in self.instances_data['images']:
127
+ image_id = image_data['id']
128
+ filename = image_data['file_name']
129
+ self.image_ids.append(image_id)
130
+ self.image_id_to_filename[image_id] = filename
131
+
132
+
133
+ # All category names (including things and stuff)
134
+ self.object_idx_to_name = {}
135
+ for category_data in self.instances_data['categories']:
136
+ self.object_idx_to_name[category_data['id']] = category_data['name']
137
+ if self.stuff_data is not None:
138
+ for category_data in self.stuff_data['categories']:
139
+ self.object_idx_to_name[category_data['id']] = category_data['name']
140
+
141
+
142
+ # Add object data from instances and stuff
143
+ self.image_id_to_objects = defaultdict(list)
144
+ self.select_objects( self.instances_data['annotations'] )
145
+ if self.stuff_data is not None:
146
+ self.select_objects( self.stuff_data['annotations'] )
147
+
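A quick usage check (not part of the commit) for the make_a_sentence() helper defined in dataset/cd_dataset.py above, assuming the repository root is on PYTHONPATH and the module's dependencies are installed: it joins object names into a comma-separated caption, and clean=True strips the "-other" suffix used by COCO-stuff category names.

from dataset.cd_dataset import make_a_sentence

print(make_a_sentence(["a dog", "a birthday cake"]))          # a dog, a birthday cake
print(make_a_sentence(["dog", "grass-other"], clean=True))    # dog, grass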