Omnibus commited on
Commit
5b2cf12
1 Parent(s): a05e88c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -15,10 +15,10 @@ ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your s
15
  pipe_box=[]
16
  @spaces.GPU
17
  def init():
18
- device="cuda"
19
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
20
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
21
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
22
  # Ensure sampler uses "trailing" timesteps.
23
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
24
  pipe_box.append(pipe)
 
15
  pipe_box=[]
16
  @spaces.GPU
17
  def init():
18
+ device="cuda:0"
19
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
20
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
21
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
22
  # Ensure sampler uses "trailing" timesteps.
23
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
24
  pipe_box.append(pipe)