p1atdev committed
Commit a188c05
1 Parent(s): 5928cde

chore: use sdpa

Files changed (1)
  1. diffusion.py +10 -2
diffusion.py CHANGED
@@ -1,12 +1,14 @@
 from PIL import Image
 
 import torch
+
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
     StableDiffusionXLPipeline,
 )
 from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
     EulerAncestralDiscreteScheduler,
 )
+from diffusers.models.attention_processor import AttnProcessor2_0
 
 try:
     import spaces
@@ -27,9 +29,9 @@ class ImageGenerator:
         self.pipe = StableDiffusionXLPipeline.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
-            custom_pipeline="lpw_stable_diffusion_xl",
             use_safetensors=True,
             add_watermarker=False,
+            custom_pipeline="lpw_stable_diffusion_xl",
         )
         self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
             model_name,
@@ -37,7 +39,13 @@ class ImageGenerator:
         )
 
         # xformers
-        self.pipe.enable_xformers_memory_efficient_attention()
+        # self.pipe.enable_xformers_memory_efficient_attention()
+        self.pipe.unet.set_attn_processor(AttnProcessor2_0())
+
+        try:
+            self.pipe = torch.compile(self.pipe)
+        except Exception as e:
+            print("torch.compile is not supported on this system")
 
         self.pipe.to("cuda")
 
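
For reference, a minimal standalone sketch of the pattern this commit adopts: replace the UNet's attention processors with PyTorch 2's scaled-dot-product attention (AttnProcessor2_0) instead of xformers, then optionally apply torch.compile. This is not the repo's code; the model id is an illustrative placeholder, and compiling only the UNet (rather than the whole pipeline, as the commit does) is shown here as one common alternative recipe.

# Sketch only: SDPA attention + optional compilation for an SDXL pipeline.
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import AttnProcessor2_0

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder model id, not this Space's model
    torch_dtype=torch.float16,
    use_safetensors=True,
)

# Swap every attention processor for the PyTorch 2 SDPA implementation.
# On torch >= 2.0 with recent diffusers this is typically already the default.
pipe.unet.set_attn_processor(AttnProcessor2_0())

# Optional: compile just the UNet; fall back silently if compilation is unsupported.
try:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
except Exception:
    print("torch.compile is not supported on this system")

pipe.to("cuda")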