Automatically check the device and supports running directly on Apple M-series chip devices

#28
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -12,6 +12,12 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
 
 
 
 
 
 
15
  config_file = hf_hub_download(
16
  "xinsir/controlnet-union-sdxl-1.0",
17
  filename="config_promax.json",
@@ -27,11 +33,11 @@ state_dict = load_state_dict(model_file)
27
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
28
  controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
29
  )
30
- model.to(device="cuda", dtype=torch.float16)
31
 
32
  vae = AutoencoderKL.from_pretrained(
33
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
34
- ).to("cuda")
35
 
36
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
37
  "SG161222/RealVisXL_V5.0_Lightning",
@@ -39,7 +45,7 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
39
  vae=vae,
40
  controlnet=model,
41
  variant="fp16",
42
- ).to("cuda")
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
@@ -185,7 +191,7 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
185
  negative_prompt_embeds,
186
  pooled_prompt_embeds,
187
  negative_pooled_prompt_embeds,
188
- ) = pipe.encode_prompt(final_prompt, "cuda", True)
189
 
190
  for image in pipe(
191
  prompt_embeds=prompt_embeds,
 
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
15
+ device = "cpu"
16
+ if torch.cuda.is_available():
17
+ device = "cuda"
18
+ elif torch.backends.mps.is_available():
19
+ device = "mps"
20
+
21
  config_file = hf_hub_download(
22
  "xinsir/controlnet-union-sdxl-1.0",
23
  filename="config_promax.json",
 
33
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
34
  controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
35
  )
36
+ model.to(device=device, dtype=torch.float16)
37
 
38
  vae = AutoencoderKL.from_pretrained(
39
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
40
+ ).to(device)
41
 
42
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
43
  "SG161222/RealVisXL_V5.0_Lightning",
 
45
  vae=vae,
46
  controlnet=model,
47
  variant="fp16",
48
+ ).to(device)
49
 
50
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
51
 
 
191
  negative_prompt_embeds,
192
  pooled_prompt_embeds,
193
  negative_pooled_prompt_embeds,
194
+ ) = pipe.encode_prompt(final_prompt, device, True)
195
 
196
  for image in pipe(
197
  prompt_embeds=prompt_embeds,