Jordan Legg committed on
Commit 6f5f495
1 Parent(s): 8a31b39

turn off xformers for cpu

Files changed (1)
  1. app.py +15 -6
app.py CHANGED
@@ -10,7 +10,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 # Load the model in FP16
 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
 
-# Move the pipeline to GPU
+# Move the pipeline to GPU if available
 pipe = pipe.to(device)
 
 # Convert text encoders to full precision
@@ -18,12 +18,21 @@ pipe.text_encoder = pipe.text_encoder.to(torch.float32)
 if hasattr(pipe, 'text_encoder_2'):
     pipe.text_encoder_2 = pipe.text_encoder_2.to(torch.float32)
 
-# Enable memory efficient attention if available
-if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
-    pipe.enable_xformers_memory_efficient_attention()
+# Enable memory efficient attention if available and on CUDA
+if device == "cuda" and hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
+    try:
+        pipe.enable_xformers_memory_efficient_attention()
+        print("xformers memory efficient attention enabled")
+    except Exception as e:
+        print(f"Could not enable memory efficient attention: {e}")
 
-# Compile the UNet for potential speedups
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+# Compile the UNet for potential speedups if on CUDA
+if device == "cuda":
+    try:
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+        print("UNet compiled for potential speedups")
+    except Exception as e:
+        print(f"Could not compile UNet: {e}")
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
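
The change follows a simple pattern: gate GPU-only optimizations behind a device check and a try/except so a CPU-only Space degrades gracefully instead of failing at startup. Below is a minimal sketch of that pattern, assuming a hypothetical helper name apply_gpu_optimizations that is not part of the commit; the hasattr guard on pipe.unet is also an assumption added here, since some pipelines expose a transformer rather than a unet.

import torch
from diffusers import DiffusionPipeline

def apply_gpu_optimizations(pipe: DiffusionPipeline, device: str) -> DiffusionPipeline:
    # Hypothetical helper, not part of the commit: skip GPU-only tweaks on CPU.
    if device != "cuda":
        return pipe
    # xformers attention is a CUDA-only optimization, so it is only tried here.
    if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"Could not enable memory efficient attention: {e}")
    # The commit compiles pipe.unet; the hasattr guard is an assumption added
    # here because some pipelines expose a transformer instead of a unet.
    if hasattr(pipe, "unet"):
        try:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        except Exception as e:
            print(f"Could not compile UNet: {e}")
    return pipe

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16
).to(device)
pipe = apply_gpu_optimizations(pipe, device)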