ameerazam08 commited on
Commit
4546757
Β·
verified Β·
1 Parent(s): dd0c8cd

device changes

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -48,11 +48,15 @@ ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your
48
  # pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
49
  # pipe = pipe.to(device)
50
 
51
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda")
52
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
53
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet).to("cuda")
 
 
54
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
55
 
 
 
56
  # Load resadapter
57
  pipe.load_lora_weights(
58
  hf_hub_download(
@@ -63,6 +67,8 @@ pipe.load_lora_weights(
63
  adapter_name="res_adapter",
64
  )
65
 
 
 
66
 
67
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
68
  if randomize_seed:
 
48
  # pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
49
  # pipe = pipe.to(device)
50
 
51
+
52
+ # Load model.
53
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device)
54
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
55
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet).to(device)
56
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
57
 
58
+
59
+
60
  # Load resadapter
61
  pipe.load_lora_weights(
62
  hf_hub_download(
 
67
  adapter_name="res_adapter",
68
  )
69
 
70
+ pipe = pipe.to(device)
71
+
72
 
73
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
74
  if randomize_seed: