radames HF staff commited on
Commit
4d08072
1 Parent(s): 4b58964

enable sfast in controlnet pipeline

Browse files
Files changed (2) hide show
  1. build-run.sh +1 -1
  2. pipelines/controlnet.py +12 -4
build-run.sh CHANGED
@@ -13,4 +13,4 @@ if [ -z ${PIPELINE+x} ]; then
13
  PIPELINE="controlnet"
14
  fi
15
  echo -e "\033[1;32m\npipeline: $PIPELINE \033[0m"
16
- python3 run.py --port 7860 --host 0.0.0.0 --pipeline $PIPELINE
 
13
  PIPELINE="controlnet"
14
  fi
15
  echo -e "\033[1;32m\npipeline: $PIPELINE \033[0m"
16
+ python3 run.py --port 7860 --host 0.0.0.0 --pipeline $PIPELINE --sfast
pipelines/controlnet.py CHANGED
@@ -173,16 +173,24 @@ class Pipeline:
173
  self.pipe.vae = AutoencoderTiny.from_pretrained(
174
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
175
  ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
176
  self.canny_torch = SobelOperator(device=device)
177
  self.pipe.set_progress_bar_config(disable=True)
178
  self.pipe.to(device=device, dtype=torch_dtype)
179
  if device.type != "mps":
180
  self.pipe.unet.to(memory_format=torch.channels_last)
181
 
182
- # check if computer has less than 64GB of RAM using sys or os
183
- if psutil.virtual_memory().total < 64 * 1024**3:
184
- self.pipe.enable_attention_slicing()
185
-
186
  if args.torch_compile:
187
  self.pipe.unet = torch.compile(
188
  self.pipe.unet, mode="reduce-overhead", fullgraph=True
 
173
  self.pipe.vae = AutoencoderTiny.from_pretrained(
174
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
175
  ).to(device)
176
+
177
+ if args.sfast:
178
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
179
+ compile,
180
+ CompilationConfig,
181
+ )
182
+
183
+ config = CompilationConfig.Default()
184
+ config.enable_xformers = True
185
+ config.enable_triton = True
186
+ config.enable_cuda_graph = True
187
+ self.pipe = compile(self.pipe, config=config)
188
  self.canny_torch = SobelOperator(device=device)
189
  self.pipe.set_progress_bar_config(disable=True)
190
  self.pipe.to(device=device, dtype=torch_dtype)
191
  if device.type != "mps":
192
  self.pipe.unet.to(memory_format=torch.channels_last)
193
 
 
 
 
 
194
  if args.torch_compile:
195
  self.pipe.unet = torch.compile(
196
  self.pipe.unet, mode="reduce-overhead", fullgraph=True