multimodalart HF staff commited on
Commit
a2bfbe9
1 Parent(s): 7c9588b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -46,7 +46,6 @@ if torch.cuda.is_available():
46
  if PREVIEW_IMAGES:
47
  previewer = Previewer()
48
  previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
49
- previewer.eval().requires_grad_(False).to(device).to(dtype)
50
 
51
  def callback_prior(i, t, latents):
52
  output = previewer(latents)
@@ -84,6 +83,7 @@ def generate(
84
  ) -> PIL.Image.Image:
85
  prior_pipeline.to("cuda")
86
  decoder_pipeline.to("cuda")
 
87
  generator = torch.Generator().manual_seed(seed)
88
  prior_output = prior_pipeline(
89
  prompt=prompt,
 
46
  if PREVIEW_IMAGES:
47
  previewer = Previewer()
48
  previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
 
49
 
50
  def callback_prior(i, t, latents):
51
  output = previewer(latents)
 
83
  ) -> PIL.Image.Image:
84
  prior_pipeline.to("cuda")
85
  decoder_pipeline.to("cuda")
86
+ previewer.eval().requires_grad_(False).to(device).to(dtype)
87
  generator = torch.Generator().manual_seed(seed)
88
  prior_output = prior_pipeline(
89
  prompt=prompt,