JingyeChen committed on
Commit
a1ce13a
1 Parent(s): 0fdd5a9
Files changed (1) hide show
  1. app.py +59 -59
app.py CHANGED
@@ -319,66 +319,66 @@ def text_to_image(prompt,keywords,radio,slider_step,slider_guidance,slider_batch
319
  while len(prompt) < 77:
320
  prompt.append(tokenizer.pad_token_id)
321
 
322
- if radio == 'TextDiffuser-2':
323
-
324
- prompts_cond = prompt
325
- prompts_nocond = [tokenizer.pad_token_id]*77
326
-
327
- prompts_cond = [prompts_cond] * slider_batch
328
- prompts_nocond = [prompts_nocond] * slider_batch
329
-
330
- prompts_cond = torch.Tensor(prompts_cond).long().cuda()
331
- prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
332
-
333
- scheduler = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
334
- scheduler.set_timesteps(slider_step)
335
- noise = torch.randn((slider_batch, 4, 64, 64)).to("cuda")
336
- input = noise
337
-
338
- encoder_hidden_states_cond = text_encoder(prompts_cond)[0]
339
- encoder_hidden_states_nocond = text_encoder(prompts_nocond)[0]
340
-
341
-
342
- for t in tqdm(scheduler.timesteps):
343
- with torch.no_grad(): # classifier free guidance
344
- noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_cond[:slider_batch]).sample # b, 4, 64, 64
345
- noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond[:slider_batch]).sample # b, 4, 64, 64
346
- noisy_residual = noise_pred_uncond + slider_guidance * (noise_pred_cond - noise_pred_uncond) # b, 4, 64, 64
347
- prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
348
- input = prev_noisy_sample
349
-
350
- # decode
351
- input = 1 / vae.config.scaling_factor * input
352
- images = vae.decode(input, return_dict=False)[0]
353
- width, height = 512, 512
354
- results = []
355
- new_image = Image.new('RGB', (2*width, 2*height))
356
- for index, image in enumerate(images.float()):
357
- image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
358
- image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
359
- image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
360
- results.append(image)
361
- row = index // 2
362
- col = index % 2
363
- new_image.paste(image, (col*width, row*height))
364
- # new_image.save(f'{args.output_dir}/pred_img_{sample_index}_{args.local_rank}.jpg')
365
- # results.insert(0, new_image)
366
- # return new_image
367
- os.system('nvidia-smi')
368
- return tuple(results), composed_prompt
369
-
370
- elif radio == 'TextDiffuser-2-LCM':
371
- generator = torch.Generator(device=pipe.device).manual_seed(random.randint(0,1000))
372
- image = pipe(
373
- prompt=user_prompt,
374
- generator=generator,
375
- # negative_prompt=negative_prompt,
376
- num_inference_steps=1,
377
- guidance_scale=1,
378
- num_images_per_prompt=slider_batch,
379
- ).images
380
- return tuple(image), composed_prompt
381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  with gr.Blocks() as demo:
383
 
384
  gr.HTML(
 
319
  while len(prompt) < 77:
320
  prompt.append(tokenizer.pad_token_id)
321
 
322
+ if radio == 'TextDiffuser-2':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
+ prompts_cond = prompt
325
+ prompts_nocond = [tokenizer.pad_token_id]*77
326
+
327
+ prompts_cond = [prompts_cond] * slider_batch
328
+ prompts_nocond = [prompts_nocond] * slider_batch
329
+
330
+ prompts_cond = torch.Tensor(prompts_cond).long().cuda()
331
+ prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
332
+
333
+ scheduler = DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="scheduler")
334
+ scheduler.set_timesteps(slider_step)
335
+ noise = torch.randn((slider_batch, 4, 64, 64)).to("cuda")
336
+ input = noise
337
+
338
+ encoder_hidden_states_cond = text_encoder(prompts_cond)[0]
339
+ encoder_hidden_states_nocond = text_encoder(prompts_nocond)[0]
340
+
341
+
342
+ for t in tqdm(scheduler.timesteps):
343
+ with torch.no_grad(): # classifier free guidance
344
+ noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_cond[:slider_batch]).sample # b, 4, 64, 64
345
+ noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond[:slider_batch]).sample # b, 4, 64, 64
346
+ noisy_residual = noise_pred_uncond + slider_guidance * (noise_pred_cond - noise_pred_uncond) # b, 4, 64, 64
347
+ prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
348
+ input = prev_noisy_sample
349
+
350
+ # decode
351
+ input = 1 / vae.config.scaling_factor * input
352
+ images = vae.decode(input, return_dict=False)[0]
353
+ width, height = 512, 512
354
+ results = []
355
+ new_image = Image.new('RGB', (2*width, 2*height))
356
+ for index, image in enumerate(images.float()):
357
+ image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
358
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
359
+ image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
360
+ results.append(image)
361
+ row = index // 2
362
+ col = index % 2
363
+ new_image.paste(image, (col*width, row*height))
364
+ # new_image.save(f'{args.output_dir}/pred_img_{sample_index}_{args.local_rank}.jpg')
365
+ # results.insert(0, new_image)
366
+ # return new_image
367
+ os.system('nvidia-smi')
368
+ return tuple(results), composed_prompt
369
+
370
+ elif radio == 'TextDiffuser-2-LCM':
371
+ generator = torch.Generator(device=pipe.device).manual_seed(random.randint(0,1000))
372
+ image = pipe(
373
+ prompt=user_prompt,
374
+ generator=generator,
375
+ # negative_prompt=negative_prompt,
376
+ num_inference_steps=1,
377
+ guidance_scale=1,
378
+ num_images_per_prompt=slider_batch,
379
+ ).images
380
+ return tuple(image), composed_prompt
381
+
382
  with gr.Blocks() as demo:
383
 
384
  gr.HTML(