JingyeChen committed on
Commit 08f7d1b
1 Parent(s): 5d34271
Files changed (1)
  1. app.py +83 -54
app.py CHANGED
@@ -82,6 +82,15 @@ unet = UNet2DConditionModel.from_pretrained(
 text_encoder.resize_token_embeddings(len(tokenizer))
 
 
+#### load lcm components
+model_id = "lambdalabs/sd-pokemon-diffusers"
+lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
+pipe = DiffusionPipeline.from_pretrained(model_id, unet=unet, tokenizer=tokenizer, text_encoder=text_encoder)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+pipe.load_lora_weights(lcm_lora_id)
+pipe.to(device="cuda")
+
+
 #### for interactive
 stack = []
 state = 0
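
The block added above follows the standard diffusers LCM-LoRA recipe: wrap the fine-tuned unet, tokenizer and text_encoder in a DiffusionPipeline, swap the scheduler for LCMScheduler, and load the LCM-LoRA weights on top. A minimal, self-contained sketch of the same recipe and of how such a pipeline is then sampled (not part of this commit; the seed, prompt and step count below are illustrative, the model and LoRA IDs are the ones used above, and the LCM branch added later in this diff calls the pipeline the same way):

import torch
from diffusers import DiffusionPipeline, LCMScheduler

# Assemble a pipeline and attach the LCM-LoRA, as in the hunk above
pipe = DiffusionPipeline.from_pretrained("lambdalabs/sd-pokemon-diffusers")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.to(device="cuda")

# LCM sampling: a handful of steps, guidance_scale=1 (no classifier-free guidance)
generator = torch.Generator(device=pipe.device).manual_seed(0)
image = pipe(
    prompt="A beautiful city skyline stamp of Shanghai",  # illustrative prompt
    generator=generator,
    num_inference_steps=4,
    guidance_scale=1,
).images[0]
image.save("sample.png")
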
@@ -187,7 +196,7 @@ def get_pixels(i, t, evt: gr.SelectData):
 
 
 
-def text_to_image(prompt,keywords,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
+def text_to_image(prompt,keywords,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
 
     global stack
     global state
@@ -199,6 +208,7 @@ def text_to_image(prompt,keywords,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
     if slider_natural:
         user_prompt = f'<|startoftext|> {user_prompt} <|endoftext|>'
         composed_prompt = user_prompt
+        prompt = tokenizer.encode(user_prompt)
     else:
         if len(stack) == 0:
 
@@ -302,55 +312,67 @@ def text_to_image(prompt,keywords,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural):
         composed_prompt = user_prompt
         prompt = tokenizer.encode(user_prompt)
 
-    prompt = prompt[:77]
-    while len(prompt) < 77:
-        prompt.append(tokenizer.pad_token_id)
-    prompts_cond = prompt
-    prompts_nocond = [tokenizer.pad_token_id]*77
-
-    prompts_cond = [prompts_cond] * slider_batch
-    prompts_nocond = [prompts_nocond] * slider_batch
-
-    prompts_cond = torch.Tensor(prompts_cond).long().cuda()
-    prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
-
-    scheduler = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
-    scheduler.set_timesteps(slider_step)
-    noise = torch.randn((slider_batch, 4, 64, 64)).to("cuda")
-    input = noise
-
-    encoder_hidden_states_cond = text_encoder(prompts_cond)[0]
-    encoder_hidden_states_nocond = text_encoder(prompts_nocond)[0]
-
-
-    for t in tqdm(scheduler.timesteps):
-        with torch.no_grad():  # classifier free guidance
-            noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_cond[:slider_batch]).sample  # b, 4, 64, 64
-            noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond[:slider_batch]).sample  # b, 4, 64, 64
-            noisy_residual = noise_pred_uncond + slider_guidance * (noise_pred_cond - noise_pred_uncond)  # b, 4, 64, 64
-            prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
-            input = prev_noisy_sample
-
-    # decode
-    input = 1 / vae.config.scaling_factor * input
-    images = vae.decode(input, return_dict=False)[0]
-    width, height = 512, 512
-    results = []
-    new_image = Image.new('RGB', (2*width, 2*height))
-    for index, image in enumerate(images.float()):
-        image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
-        image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
-        results.append(image)
-        row = index // 2
-        col = index % 2
-        new_image.paste(image, (col*width, row*height))
-    # new_image.save(f'{args.output_dir}/pred_img_{sample_index}_{args.local_rank}.jpg')
-    # results.insert(0, new_image)
-    # return new_image
-    os.system('nvidia-smi')
-    return tuple(results), composed_prompt
-
+    if radio == 'TextDiffuser-2':
+        prompt = prompt[:77]
+        while len(prompt) < 77:
+            prompt.append(tokenizer.pad_token_id)
+        prompts_cond = prompt
+        prompts_nocond = [tokenizer.pad_token_id]*77
+
+        prompts_cond = [prompts_cond] * slider_batch
+        prompts_nocond = [prompts_nocond] * slider_batch
+
+        prompts_cond = torch.Tensor(prompts_cond).long().cuda()
+        prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
+
+        scheduler = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
+        scheduler.set_timesteps(slider_step)
+        noise = torch.randn((slider_batch, 4, 64, 64)).to("cuda")
+        input = noise
+
+        encoder_hidden_states_cond = text_encoder(prompts_cond)[0]
+        encoder_hidden_states_nocond = text_encoder(prompts_nocond)[0]
+
+
+        for t in tqdm(scheduler.timesteps):
+            with torch.no_grad():  # classifier free guidance
+                noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_cond[:slider_batch]).sample  # b, 4, 64, 64
+                noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond[:slider_batch]).sample  # b, 4, 64, 64
+                noisy_residual = noise_pred_uncond + slider_guidance * (noise_pred_cond - noise_pred_uncond)  # b, 4, 64, 64
+                prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
+                input = prev_noisy_sample
+
+        # decode
+        input = 1 / vae.config.scaling_factor * input
+        images = vae.decode(input, return_dict=False)[0]
+        width, height = 512, 512
+        results = []
+        new_image = Image.new('RGB', (2*width, 2*height))
+        for index, image in enumerate(images.float()):
+            image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
+            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+            image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
+            results.append(image)
+            row = index // 2
+            col = index % 2
+            new_image.paste(image, (col*width, row*height))
+        # new_image.save(f'{args.output_dir}/pred_img_{sample_index}_{args.local_rank}.jpg')
+        # results.insert(0, new_image)
+        # return new_image
+        os.system('nvidia-smi')
+        return tuple(results), composed_prompt
+
+    elif radio == 'TextDiffuser-2-LCM':
+        generator = torch.Generator(device=pipe.device).manual_seed(random.randint(0,1000))
+        image = pipe(
+            prompt=user_prompt,
+            generator=generator,
+            # negative_prompt=negative_prompt,
+            num_inference_steps=slider_step,
+            guidance_scale=1,
+        ).images[0]
+        return tuple([image]), composed_prompt
+
 with gr.Blocks() as demo:
 
     gr.HTML(
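
The TextDiffuser-2 branch above runs the denoising loop by hand and merges the two UNet predictions with classifier-free guidance. A toy illustration of that single combination step on dummy tensors (slider_guidance plays the role of w; shapes follow the "b, 4, 64, 64" comments above):

import torch

w = 7.5  # slider_guidance in the code above
noise_pred_cond = torch.randn(4, 4, 64, 64)    # UNet output conditioned on the prompt
noise_pred_uncond = torch.randn(4, 4, 64, 64)  # UNet output on the all-pad prompt
# Classifier-free guidance: push the prediction away from the unconditional one
noise_pred = noise_pred_uncond + w * (noise_pred_cond - noise_pred_uncond)

At w = 1 the expression reduces to the conditional prediction; larger w pushes the sample harder toward the prompt, which is also why the LCM branch passes guidance_scale=1 and skips the unconditional pass.
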
@@ -359,6 +381,12 @@ with gr.Blocks() as demo:
         <h2 style="font-weight: 900; font-size: 2.5rem; margin: 0rem">
             TextDiffuser-2: Unleashing the Power of Language Models for Text Rendering
         </h2>
+        <h2 style="font-weight: 460; font-size: 1.1rem; margin: 0rem">
+            <a href="https://jingyechen.github.io/">Jingye Chen</a>, <a href="https://hypjudy.github.io/website/">Yupan Huang</a>, <a href="https://scholar.google.com/citations?user=0LTZGhUAAAAJ&hl=en">Tengchao Lv</a>, <a href="https://www.microsoft.com/en-us/research/people/lecu/">Lei Cui</a>, <a href="https://cqf.io/">Qifeng Chen</a>, <a href="https://thegenerality.com/">Furu Wei</a>
+        </h2>
+        <h2 style="font-weight: 460; font-size: 1.1rem; margin: 0rem">
+            HKUST, Sun Yat-sen University, Microsoft Research
+        </h2>
         <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
             [<a href="https://arxiv.org/abs/2311.16465" style="color:blue;">arXiv</a>]
             [<a href="https://github.com/microsoft/unilm/tree/master/textdiffuser-2" style="color:blue;">Code</a>]
@@ -376,7 +404,7 @@ with gr.Blocks() as demo:
         }
         </style>
 
-        <img src="https://i.ibb.co/vmrXRb5/architecture.jpg" alt="textdiffuser-2" class="scaled-image">
+        <img src="https://i.ibb.co/56JVg5j/architecture.jpg" alt="textdiffuser-2" class="scaled-image">
     </div>
     """)
 
@@ -403,8 +431,8 @@ with gr.Blocks() as demo:
         undo.click(exe_undo, [i,t],[i])
         skip_button.click(skip_fun, [i,t])
 
-        # radio = gr.Radio(["Stable Diffusion v2.1", "Stable Diffusion v1.5"], label="Pre-trained Model", value="Stable Diffusion v1.5")
-        slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser.")
+        radio = gr.Radio(["TextDiffuser-2", "TextDiffuser-2-LCM"], label="Choices of models", value="TextDiffuser-2")
+        slider_step = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Sampling step", info="The sampling step for TextDiffuser-2.")
         slider_guidance = gr.Slider(minimum=1, maximum=9, value=7.5, step=0.5, label="Scale of classifier-free guidance", info="The scale of classifier-free guidance and is set to 7.5 in default.")
         slider_batch = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Batch size", info="The number of images to be sampled.")
        slider_temperature = gr.Slider(minimum=0.1, maximum=2, value=0.7, step=0.1, label="Temperature", info="Control the diversity of layout planner. Higher value indicates more diversity.")
@@ -425,12 +453,13 @@ with gr.Blocks() as demo:
 
         # gr.Markdown("## Prompt Examples")
 
-        button.click(text_to_image, inputs=[prompt,keywords,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural], outputs=[output, composed_prompt])
+        button.click(text_to_image, inputs=[prompt,keywords,radio,slider_step,slider_guidance,slider_batch,slider_temperature,slider_natural], outputs=[output, composed_prompt])
 
         gr.Markdown("## Prompt Examples")
         gr.Examples(
             [
                 ["A beautiful city skyline stamp of Shanghai", ""],
+                ["A stamp of U.S.A.", ""],
                 ["A book cover named summer vibe", ""],
             ],
             prompt,