Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -22,8 +22,10 @@ if not torch.cuda.is_available():
|
|
22 |
|
23 |
MAX_SEED = np.iinfo(np.int32).max
|
24 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
25 |
-
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "
|
26 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
|
|
|
|
27 |
|
28 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
29 |
|
@@ -105,6 +107,12 @@ if torch.cuda.is_available():
|
|
105 |
print("Using DALL-E 3 Consistency Decoder")
|
106 |
pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
# speed-up T5
|
109 |
pipe.text_encoder.to_bettertransformer()
|
110 |
|
@@ -125,6 +133,9 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
|
125 |
return seed
|
126 |
|
127 |
|
|
|
|
|
|
|
128 |
def generate(
|
129 |
prompt: str,
|
130 |
negative_prompt: str = "",
|
|
|
22 |
|
23 |
MAX_SEED = np.iinfo(np.int32).max
|
24 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
25 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "3000"))
|
26 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
27 |
+
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
28 |
+
PORT = int(os.getenv("DEMO_PORT", "15432"))
|
29 |
|
30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31 |
|
|
|
107 |
print("Using DALL-E 3 Consistency Decoder")
|
108 |
pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
109 |
|
110 |
+
if ENABLE_CPU_OFFLOAD:
|
111 |
+
pipe.enable_model_cpu_offload()
|
112 |
+
else:
|
113 |
+
pipe.to(device)
|
114 |
+
print("Loaded on Device!")
|
115 |
+
|
116 |
# speed-up T5
|
117 |
pipe.text_encoder.to_bettertransformer()
|
118 |
|
|
|
133 |
return seed
|
134 |
|
135 |
|
136 |
+
@torch.no_grad()
|
137 |
+
@torch.inference_mode()
|
138 |
+
@spaces.GPU(duration=30)
|
139 |
def generate(
|
140 |
prompt: str,
|
141 |
negative_prompt: str = "",
|