silentchen committed
Commit 848367f • 1 Parent(s): 4432a94

Upload app.py

Files changed (1)
  1. app.py +143 -134
app.py CHANGED
@@ -11,7 +11,7 @@ import math
  from utils import compute_ca_loss
  from gradio import processing_utils
  from typing import Optional
-
+ import spaces
  import warnings

  import sys
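
The added `import spaces` brings in Hugging Face's `spaces` package, which this commit later uses to decorate the GPU-heavy `inference` call with `@spaces.GPU(duration=180)` so that a ZeroGPU Space allocates a GPU only for the duration of that call. A minimal sketch of the pattern, assuming the code runs on a ZeroGPU Space (the function below is a placeholder, not part of app.py):

    import spaces
    import torch

    @spaces.GPU(duration=180)  # request a GPU for up to 180 s per call
    def run_on_gpu(x: torch.Tensor) -> torch.Tensor:
        # any CUDA work happens inside the decorated function
        return (x.to("cuda") * 2).cpu()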
@@ -67,96 +67,7 @@ def draw_box(boxes=[], texts=[], img=None):
                fill=(255, 255, 255))
      return img

- '''
- inference model
- '''
-
- def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
-     uncond_input = tokenizer(
-         [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
-     )
-     uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
-
-     input_ids = tokenizer(
-         prompt,
-         padding="max_length",
-         truncation=True,
-         max_length=tokenizer.model_max_length,
-         return_tensors="pt",
-     ).input_ids[0].unsqueeze(0).to(device)
-     # text_embeddings = text_encoder(input_ids)[0]
-     text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
-     # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
-     generator = torch.manual_seed(rand_seed)  # Seed generator to create the inital latent noise
-
-     latents = torch.randn(
-         (batch_size, 4, 64, 64),
-         generator=generator,
-     ).to(device)
-
-     noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
-     # generator = torch.Generator("cuda").manual_seed(1024)
-     noise_scheduler.set_timesteps(51)
-
-     latents = latents * noise_scheduler.init_noise_sigma
-
-     loss = torch.tensor(10000)
-
-     for index, t in enumerate(noise_scheduler.timesteps):
-         iteration = 0
-
-         while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
-             latents = latents.requires_grad_(True)
-
-             # latent_model_input = torch.cat([latents] * 2)
-             latent_model_input = latents

-             latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
-             noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
-                 unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
-
-             # update latents with guidence from gaussian blob
-
-             loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
-                                    object_positions=object_positions) * loss_scale
-
-             print(loss.item() / loss_scale)
-
-             grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
-
-             latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
-             iteration += 1
-             torch.cuda.empty_cache()
-         torch.cuda.empty_cache()
-
-
-         with torch.no_grad():
-
-             latent_model_input = torch.cat([latents] * 2)
-
-             latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
-             noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
-                 unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
-
-             noise_pred = noise_pred.sample
-
-             # perform classifier-free guidance
-             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-             latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
-             torch.cuda.empty_cache()
-     # Decode image
-     with torch.no_grad():
-         # print("decode image")
-         latents = 1 / 0.18215 * latents
-         image = vae.decode(latents).sample
-         image = (image / 2 + 0.5).clamp(0, 1)
-         image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-         images = (image * 255).round().astype("uint8")
-         pil_images = [Image.fromarray(image) for image in images]
-     return pil_images

  def get_concat(ims):
      if len(ims) == 1:
@@ -172,42 +83,6 @@ def get_concat(ims):
      return dst


- def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
-              loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
-              state):
-     if 'boxes' not in state:
-         state['boxes'] = []
-     boxes = state['boxes']
-     grounding_texts = [x.strip() for x in grounding_texts.split(';')]
-     # assert len(boxes) == len(grounding_texts)
-     if len(boxes) != len(grounding_texts):
-         if len(boxes) < len(grounding_texts):
-             raise ValueError("""The number of boxes should be equal to the number of grounding objects.
- Number of boxes drawn: {}, number of grounding tokens: {}.
- Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
-         grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
-
-     boxes = (np.asarray(boxes) / 512).tolist()
-     boxes = [[box] for box in boxes]
-     grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
-     language_instruction_list = language_instruction.strip('.').split(' ')
-     object_positions = []
-     for obj in grounding_texts:
-         obj_position = []
-         for word in obj.split(' '):
-             obj_first_index = language_instruction_list.index(word) + 1
-             obj_position.append(obj_first_index)
-         object_positions.append(obj_position)
-     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-     gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes, object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed, guidance_scale)
-
-     blank_samples = batch_size % 2 if batch_size > 1 else 0
-     gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
-                  + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
-                  + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
-
-     return gen_images + [state]


  def binarize(x):
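
The removed `generate` (re-added inside `main()` later in this diff) turns each grounding phrase into 1-based word positions within the language instruction by plain list lookup; the `+ 1` presumably leaves room for the tokenizer's start token. A small worked example of that indexing, using the same variable names (the prompt here is only an illustration):

    language_instruction = "A hello kitty toy is playing with a purple ball."
    grounding_texts = ["hello kitty toy", "purple ball"]

    language_instruction_list = language_instruction.strip('.').split(' ')
    object_positions = []
    for obj in grounding_texts:
        obj_position = [language_instruction_list.index(word) + 1 for word in obj.split(' ')]
        object_positions.append(obj_position)

    print(object_positions)  # [[2, 3, 4], [9, 10]]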
@@ -251,8 +126,9 @@ def center_crop(img, HW=None, tgt_size=(512, 512)):

  def draw(input, grounding_texts, new_image_trigger, state):
      if type(input) == dict:
-         image = input['image']
-         mask = input['mask']
+         # import pdb; pdb.set_trace()
+         # image = input['composite']
+         mask = input['composite']
      else:
          mask = input
      if mask.ndim == 3:
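
Reading `input['composite']` instead of `input['image']` / `input['mask']` matches the value produced by Gradio 4's image-editing components (`gr.ImageEditor` and its `gr.Paint` preset), whose callbacks receive a dict with 'background', 'layers' and 'composite' entries. A minimal sketch of handling such a value, assuming Gradio 4.x conventions:

    import numpy as np

    def editor_value_to_mask(value):
        # 'composite' is the background with all drawn layers flattened onto it
        composite = np.asarray(value["composite"]) if isinstance(value, dict) else np.asarray(value)
        if composite.ndim == 3:      # keep a single channel as the sketch mask
            composite = composite[..., 0]
        return composite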
@@ -307,7 +183,7 @@ def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
      if task != 'Grounded Inpainting':
          sketch_pad_trigger = sketch_pad_trigger + 1
      blank_samples = batch_size % 2 if batch_size > 1 else 0
-     out_images = [gr.Image.update(value=None, visible=True) for i in range(batch_size)]
+     out_images = [gr.Image.change(value=None, visible=True) for i in range(batch_size)]
      # state = {}
      return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]
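
Both sides of this hunk are affected by the removal of `gr.Image.update(...)` in Gradio 4; there, a callback updates a component by returning `gr.update(...)` (or a new component instance). A minimal sketch of that replacement pattern, assuming Gradio 4.x; it is shown for comparison only and is not what the commit does:

    import gradio as gr

    def clear_outputs(batch_size):
        # one update per output slot: clear the value but keep the slot visible
        return [gr.update(value=None, visible=True) for _ in range(batch_size)]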
 
@@ -387,6 +263,139 @@ def main():
      text_encoder.to(device)
      vae.to(device)

+     def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
+                  loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
+                  state):
+         if 'boxes' not in state:
+             state['boxes'] = []
+         boxes = state['boxes']
+         grounding_texts = [x.strip() for x in grounding_texts.split(';')]
+         # assert len(boxes) == len(grounding_texts)
+         if len(boxes) != len(grounding_texts):
+             if len(boxes) < len(grounding_texts):
+                 raise ValueError("""The number of boxes should be equal to the number of grounding objects.
+ Number of boxes drawn: {}, number of grounding tokens: {}.
+ Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
+             grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))
+
+         boxes = (np.asarray(boxes) / 512).tolist()
+         boxes = [[box] for box in boxes]
+         grounding_instruction = json.dumps({obj: box for obj, box in zip(grounding_texts, boxes)})
+         language_instruction_list = language_instruction.strip('.').split(' ')
+         object_positions = []
+         for obj in grounding_texts:
+             obj_position = []
+             for word in obj.split(' '):
+                 obj_first_index = language_instruction_list.index(word) + 1
+                 obj_position.append(obj_first_index)
+             object_positions.append(obj_position)
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+         gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
+                                object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step, rand_seed,
+                                guidance_scale)
+
+         blank_samples = batch_size % 2 if batch_size > 1 else 0
+         gen_images = [gr.Image.update(value=x, visible=True) for i, x in enumerate(gen_images)] \
+                      + [gr.Image.change(fn=None, show_api=True) for _ in range(blank_samples)] \
+                      + [gr.Image.change(fn=None, show_api=False) for _ in range(4 - batch_size - blank_samples)]
+
+         return gen_images + [state]
+
+     '''
+     inference model
+     '''
+
+     @spaces.GPU(duration=180)
+     def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size, loss_scale,
+                   loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
+         uncond_input = tokenizer(
+             [""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+         )
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
+
+         input_ids = tokenizer(
+             prompt,
+             padding="max_length",
+             truncation=True,
+             max_length=tokenizer.model_max_length,
+             return_tensors="pt",
+         ).input_ids[0].unsqueeze(0).to(device)
+         # text_embeddings = text_encoder(input_ids)[0]
+         text_embeddings = torch.cat([uncond_embeddings, text_encoder(input_ids)[0]])
+         # text_embeddings[1, 1, :] = text_embeddings[1, 2, :]
+         generator = torch.manual_seed(rand_seed)  # Seed generator to create the inital latent noise
+
+         latents = torch.randn(
+             (batch_size, 4, 64, 64),
+             generator=generator,
+         ).to(device)
+
+         noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                                                num_train_timesteps=1000)
+
+         # generator = torch.Generator("cuda").manual_seed(1024)
+         noise_scheduler.set_timesteps(51)
+
+         latents = latents * noise_scheduler.init_noise_sigma
+
+         loss = torch.tensor(10000)
+
+         for index, t in enumerate(noise_scheduler.timesteps):
+             iteration = 0
+
+             while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
+                 latents = latents.requires_grad_(True)
+
+                 # latent_model_input = torch.cat([latents] * 2)
+                 latent_model_input = latents
+
+                 latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
+                 noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
+                     unet(latent_model_input, t, encoder_hidden_states=text_encoder(input_ids)[0])
+
+                 # update latents with guidence from gaussian blob
+
+                 loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
+                                        object_positions=object_positions) * loss_scale
+
+                 print(loss.item() / loss_scale)
+
+                 grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
+
+                 latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
+                 iteration += 1
+                 torch.cuda.empty_cache()
+             torch.cuda.empty_cache()
+
+             with torch.no_grad():
+
+                 latent_model_input = torch.cat([latents] * 2)
+
+                 latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
+                 noise_pred, attn_map_integrated_up, attn_map_integrated_mid, attn_map_integrated_down = \
+                     unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
+
+                 noise_pred = noise_pred.sample
+
+                 # perform classifier-free guidance
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
+                 torch.cuda.empty_cache()
+         # Decode image
+         with torch.no_grad():
+             # print("decode image")
+             latents = 1 / 0.18215 * latents
+             image = vae.decode(latents).sample
+             image = (image / 2 + 0.5).clamp(0, 1)
+             image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+             images = (image * 255).round().astype("uint8")
+             pil_images = [Image.fromarray(image) for image in images]
+             return pil_images
+
+
      with Blocks(
          css=css,
          analytics_enabled=False,
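
Per timestep, the re-added `inference` does two things: it pushes the latents down the gradient of the cross-attention layout loss, scaled by sigma_t squared, and then takes an ordinary classifier-free-guidance denoising step. A stripped-down sketch of those two updates; `unet_fn` and `loss_fn` are placeholders standing in for the UNet call and `compute_ca_loss` above, so this is an illustration of the scheme rather than the app's exact code:

    import torch

    def guided_step(latents, t, index, noise_scheduler, unet_fn, loss_fn,
                    loss_scale, guidance_scale):
        # 1) layout guidance: gradient of the attention loss w.r.t. the latents
        latents = latents.detach().requires_grad_(True)
        _, attn_maps = unet_fn(noise_scheduler.scale_model_input(latents, t), t, cond="text")
        loss = loss_fn(attn_maps) * loss_scale
        grad = torch.autograd.grad(loss, [latents])[0]
        latents = (latents - grad * noise_scheduler.sigmas[index] ** 2).detach()

        # 2) classifier-free guidance on the noise prediction
        with torch.no_grad():
            model_in = noise_scheduler.scale_model_input(torch.cat([latents] * 2), t)
            noise_pred, _ = unet_fn(model_in, t, cond="uncond+text")
            noise_uncond, noise_text = noise_pred.chunk(2)
            noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
        return latents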
@@ -418,7 +427,7 @@ def main():


          with gr.Row():
-             sketch_pad = gr.Paint(label="Sketch Pad", elem_id="img2img_image", source='canvas', shape=(512, 512))
+             sketch_pad = gr.Paint(label="Sketch Pad", container=False, layers=False, scale=1, elem_id="img2img_image")
              out_imagebox = gr.Image(type="pil", label="Parsed Sketch Pad")
              out_gen_1 = gr.Image(type="pil", visible=True, label="Generated Image")
@@ -479,7 +488,7 @@ def main():
              inputs=sketch_pad_trigger,
              outputs=sketch_pad_trigger,
              queue=False)
-         sketch_pad.edit(
+         sketch_pad.change(
              draw,
              inputs=[sketch_pad, grounding_instruction, sketch_pad_resize_trigger, state],
              outputs=[out_imagebox, sketch_pad_resize_trigger, image_scale, state],
@@ -519,13 +528,13 @@ def main():
              None,
              None,
              sketch_pad_resize_trigger,
-             _js=rescale_js,
+             js=rescale_js,
              queue=False)
          init_white_trigger.change(
              None,
              None,
              init_white_trigger,
-             _js=rescale_js,
+             js=rescale_js,
              queue=False)

          with gr.Column():
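
The `_js=` to `js=` rename follows the Gradio 4 event-listener signature, where client-side JavaScript is passed as `js=`. A one-block sketch of the new form, assuming Gradio 4.x (the components and the JS snippet are placeholders):

    import gradio as gr

    with gr.Blocks() as demo:
        trigger = gr.Number(value=0, visible=False)
        # Gradio 4.x: `js=` replaces the Gradio 3.x `_js=` argument
        trigger.change(fn=None, inputs=trigger, outputs=trigger, js="(x) => x + 1", queue=False)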
@@ -546,7 +555,7 @@ def main():
          description = """<p> The source codes of the demo are modified based on the <a href="https://huggingface.co/spaces/gligen/demo/tree/main">GlIGen</a>. Thanks! </p>"""
          gr.HTML(description)

-     demo.queue(concurrency_count=1, api_open=False)
+     demo.queue(api_open=False)
      demo.launch(share=False, show_api=False, show_error=True)

  if __name__ == '__main__':
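
Dropping `concurrency_count=1` matches the Gradio 4 queue API, where concurrency is set through `default_concurrency_limit` (or a per-event `concurrency_limit`) instead. A minimal sketch, assuming Gradio 4.x and that the old limit of one worker should be kept; `demo` is the Blocks app built above:

    # Gradio 4.x counterpart of queue(concurrency_count=1, api_open=False)
    demo.queue(default_concurrency_limit=1, api_open=False)
    demo.launch(share=False, show_api=False, show_error=True)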
 