Spaces:
Runtime error
Runtime error
Samuel L Meyers
commited on
Commit
·
78fd303
1
Parent(s):
826a722
Run on cpu anyway, even if it takes 1m years/img
Browse files
app.py
CHANGED
@@ -12,6 +12,8 @@ from PIL import Image
|
|
12 |
import torch
|
13 |
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
|
14 |
from sa_solver_diffusers import SASolverScheduler
|
|
|
|
|
15 |
|
16 |
|
17 |
DESCRIPTION = """data:image/s3,"s3://crabby-images/a33ad/a33ad89ac68c278d7c95bd62986f7b90b70f3896" alt="Logo"
|
@@ -100,26 +102,11 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
|
|
100 |
return p.replace("{prompt}", positive), n + negative
|
101 |
|
102 |
|
103 |
-
|
104 |
-
pipe = PixArtAlphaPipeline.from_pretrained(
|
105 |
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
106 |
-
torch_dtype=torch.float16,
|
107 |
use_safetensors=True,
|
108 |
)
|
109 |
|
110 |
-
if ENABLE_CPU_OFFLOAD:
|
111 |
-
pipe.enable_model_cpu_offload()
|
112 |
-
else:
|
113 |
-
pipe.to(device)
|
114 |
-
print("Loaded on Device!")
|
115 |
-
|
116 |
-
# speed-up T5
|
117 |
-
pipe.text_encoder.to_bettertransformer()
|
118 |
-
|
119 |
-
if USE_TORCH_COMPILE:
|
120 |
-
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
|
121 |
-
print("Model Compiled!")
|
122 |
-
|
123 |
|
124 |
def save_image(img):
|
125 |
unique_name = str(uuid.uuid4()) + ".png"
|
|
|
12 |
import torch
|
13 |
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
|
14 |
from sa_solver_diffusers import SASolverScheduler
|
15 |
+
from transformers import AutoTokenizer
|
16 |
+
from transformers import pipeline as pipe
|
17 |
|
18 |
|
19 |
DESCRIPTION = """data:image/s3,"s3://crabby-images/a33ad/a33ad89ac68c278d7c95bd62986f7b90b70f3896" alt="Logo"
|
|
|
102 |
return p.replace("{prompt}", positive), n + negative
|
103 |
|
104 |
|
105 |
+
pipe = PixArtAlphaPipeline.from_pretrained(
|
|
|
106 |
"PixArt-alpha/PixArt-XL-2-1024-MS",
|
|
|
107 |
use_safetensors=True,
|
108 |
)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
def save_image(img):
|
112 |
unique_name = str(uuid.uuid4()) + ".png"
|