multimodalart (HF staff) committed
Commit: dd02dfa
1 parent: c5c2b01

Update app.py

Files changed (1):
    app.py: +8 -6
app.py CHANGED
@@ -24,7 +24,7 @@ CACHE_EXAMPLES = False #torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES"
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
 USE_TORCH_COMPILE = False
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
-PREVIEW_IMAGES = False
+PREVIEW_IMAGES = True
 
 dtype = torch.bfloat16
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -47,10 +47,12 @@ if torch.cuda.is_available():
     previewer = Previewer()
     previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
     previewer.load_state_dict(previewer_state_dict)
-    def callback_prior(i, t, latents):
+    def callback_prior(pipeline, step_index, t, callback_kwargs):
+        latents = callback_kwargs["latents"]
         output = previewer(latents)
         output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
-        return output
+        callback_kwargs["preview_output"] = output
+        return callback_kwargs
     callback_steps = 1
 else:
     previewer = None
@@ -84,7 +86,7 @@ def generate(
     profile: gr.OAuthProfile | None = None,
 ) -> PIL.Image.Image:
 
-    #previewer.eval().requires_grad_(False).to(device).to(dtype)
+    previewer.eval().requires_grad_(False).to(device).to(dtype)
     prior_pipeline.to(device)
     decoder_pipeline.to(device)
 
@@ -100,8 +102,8 @@ def generate(
         guidance_scale=prior_guidance_scale,
         num_images_per_prompt=num_images_per_prompt,
        generator=generator,
-        #callback=callback_prior,
-        #callback_steps=callback_steps
+        callback_on_step_end=callback_prior,
+        callback_on_step_end_tensor_inputs=['latents']
     )
     print(prior_output)
     if PREVIEW_IMAGES:
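
For context: this change re-enables live previews by switching from the older callback= / callback_steps= arguments to the step-end callback interface used by current diffusers pipelines, where the callback receives (pipeline, step_index, timestep, callback_kwargs) and returns the kwargs dict. The snippet below is a minimal standalone sketch of that interface, not code from this Space; the model id, prompt, and step count are placeholders, and it assumes a diffusers version whose StableCascadePriorPipeline.__call__ accepts callback_on_step_end and callback_on_step_end_tensor_inputs.

# Minimal sketch of the diffusers step-end callback interface adopted above.
# Assumes a recent diffusers release; model id, prompt, and step count are
# placeholder values chosen for illustration only.
import torch
from diffusers import StableCascadePriorPipeline

def on_step_end(pipeline, step_index, timestep, callback_kwargs):
    # callback_kwargs holds the tensors requested via
    # callback_on_step_end_tensor_inputs -- here just "latents".
    latents = callback_kwargs["latents"]
    print(f"step {step_index} (t={timestep}): latents {tuple(latents.shape)}")
    # The callback must return the (possibly modified) kwargs dict.
    return callback_kwargs

pipe = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")

out = pipe(
    prompt="a photo of an astronaut",
    num_inference_steps=20,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
)

With this shape, the callback runs at the end of every denoising step and can inspect (or, as in the commit, decode and preview) the requested tensors, so no separate callback_steps argument is needed.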