multimodalart HF staff commited on
Commit
7db2e08
1 Parent(s): 95d3f19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -14
app.py CHANGED
@@ -43,15 +43,14 @@ if torch.cuda.is_available():
43
  decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
44
 
45
  if PREVIEW_IMAGES:
46
- pass
47
- # previewer = Previewer()
48
- # previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
49
- # previewer.eval().requires_grad_(False).to(device).to(dtype)
50
 
51
- # def callback_prior(i, t, latents):
52
- # output = previewer(latents)
53
- # output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
54
- # return output
55
 
56
  else:
57
  previewer = None
@@ -97,12 +96,12 @@ def generate(
97
  callback=callback_prior,
98
  )
99
 
100
- #if PREVIEW_IMAGES:
101
- # for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
102
- # r = next(prior_output)
103
- # if isinstance(r, list):
104
- # yield r
105
- # prior_output = r
106
 
107
  decoder_output = decoder_pipeline(
108
  image_embeddings=prior_output.image_embeddings,
 
43
  decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
44
 
45
  if PREVIEW_IMAGES:
46
+ previewer = Previewer()
47
+ previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
48
+ previewer.eval().requires_grad_(False).to(device).to(dtype)
 
49
 
50
+ def callback_prior(i, t, latents):
51
+ output = previewer(latents)
52
+ output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
53
+ return output
54
 
55
  else:
56
  previewer = None
 
96
  callback=callback_prior,
97
  )
98
 
99
+ if PREVIEW_IMAGES:
100
+ for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
101
+ r = next(prior_output)
102
+ if isinstance(r, list):
103
+ yield r[0]
104
+ prior_output = r
105
 
106
  decoder_output = decoder_pipeline(
107
  image_embeddings=prior_output.image_embeddings,