multimodalart committed
Commit 1a89b82
Parent: 7359460

Update app.py

Files changed (1)
  1. app.py +8 -7
app.py CHANGED
@@ -29,8 +29,8 @@ PREVIEW_IMAGES = True
 dtype = torch.bfloat16
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 if torch.cuda.is_available():
-    prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device)
-    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device)
+    prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
+    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
 
     if ENABLE_CPU_OFFLOAD:
         prior_pipeline.enable_model_cpu_offload()
@@ -45,8 +45,8 @@ if torch.cuda.is_available():
 
     if PREVIEW_IMAGES:
         previewer = Previewer()
-        previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
-        previewer.eval().requires_grad_(False).to(device).to(dtype)
+        previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
+        previewer.load_state_dict(previewer_state_dict)
         def callback_prior(i, t, latents):
             output = previewer(latents)
             output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
@@ -82,9 +82,10 @@ def generate(
     num_images_per_prompt: int = 2,
     profile: gr.OAuthProfile | None = None,
 ) -> PIL.Image.Image:
-    #prior_pipeline.to(device)
-    #decoder_pipeline.to(device)
-    #previewer.eval().requires_grad_(False).to(device).to(dtype)
+    previewer.eval().requires_grad_(False).to(device).to(dtype)
+    prior_pipeline.to(device)
+    decoder_pipeline.to(device)
+
     generator = torch.Generator().manual_seed(seed)
     prior_output = prior_pipeline(
         prompt=prompt,
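
In short, the commit defers all GPU placement to request time: the two pipelines are no longer moved to the GPU when the module loads, the previewer checkpoint is read with map_location=torch.device('cpu'), and generate() now places the previewer and both pipelines on device before sampling. Below is a minimal sketch of the resulting startup-and-request pattern, assuming the diffusers Stable Cascade pipelines used above; the Space's Previewer module is omitted and the generate() signature is abbreviated for illustration:

import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

dtype = torch.bfloat16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Startup: load both pipelines on the CPU, without any .to(device).
prior_pipeline = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=dtype
)
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=dtype
)

# map_location keeps a CUDA-saved checkpoint from trying to allocate
# GPU memory while the module is being imported.
previewer_state_dict = torch.load(
    "previewer/previewer_v1_100k.pt", map_location=torch.device("cpu")
)["state_dict"]

def generate(prompt: str, seed: int = 0):
    # Request time: move the pipelines to the GPU only when they are needed.
    prior_pipeline.to(device)
    decoder_pipeline.to(device)
    generator = torch.Generator().manual_seed(seed)
    prior_output = prior_pipeline(prompt=prompt, generator=generator)
    ...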