radames commited on
Commit
5fa0aa7
1 Parent(s): 0290645

option to enable taesd sdxl

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -2,6 +2,7 @@ from diffusers import (
2
  StableDiffusionXLPipeline,
3
  EulerDiscreteScheduler,
4
  UNet2DConditionModel,
 
5
  )
6
  import torch
7
  import os
@@ -19,6 +20,7 @@ BASE = "stabilityai/stable-diffusion-xl-base-1.0"
19
  REPO = "ByteDance/SDXL-Lightning"
20
  # 1-step
21
  CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
 
22
 
23
  # {
24
  # "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
@@ -30,6 +32,8 @@ CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
30
 
31
  SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
32
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
 
 
33
  # check if MPS is available OSX only M1/M2/M3 chips
34
 
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -38,6 +42,7 @@ torch_dtype = torch.float16
38
 
39
  print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
40
  print(f"SFAST_COMPILE: {SFAST_COMPILE}")
 
41
  print(f"device: {device}")
42
 
43
 
@@ -49,6 +54,12 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
49
  BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
50
  ).to("cuda")
51
 
 
 
 
 
 
 
52
  # Ensure sampler uses "trailing" timesteps.
53
  pipe.scheduler = EulerDiscreteScheduler.from_config(
54
  pipe.scheduler.config, timestep_spacing="trailing"
 
2
  StableDiffusionXLPipeline,
3
  EulerDiscreteScheduler,
4
  UNet2DConditionModel,
5
+ AutoencoderTiny,
6
  )
7
  import torch
8
  import os
 
20
  REPO = "ByteDance/SDXL-Lightning"
21
  # 1-step
22
  CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
23
+ taesd_model = "madebyollin/taesdxl"
24
 
25
  # {
26
  # "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
 
32
 
33
  SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
34
  SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
35
+ USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"
36
+
37
  # check if MPS is available OSX only M1/M2/M3 chips
38
 
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
42
 
43
  print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
44
  print(f"SFAST_COMPILE: {SFAST_COMPILE}")
45
+ print(f"USE_TAESD: {USE_TAESD}")
46
  print(f"device: {device}")
47
 
48
 
 
54
  BASE, unet=unet, torch_dtype=torch.float16, variant="fp16", safety_checker=False
55
  ).to("cuda")
56
 
57
+ if USE_TAESD:
58
+ pipe.vae = AutoencoderTiny.from_pretrained(
59
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
60
+ ).to(device)
61
+
62
+
63
  # Ensure sampler uses "trailing" timesteps.
64
  pipe.scheduler = EulerDiscreteScheduler.from_config(
65
  pipe.scheduler.config, timestep_spacing="trailing"