Spaces:
Runtime error
Runtime error
Commit
•
b50750a
1
Parent(s):
067f7c8
Update app.py
Browse files
app.py
CHANGED
@@ -26,11 +26,11 @@ USE_TORCH_COMPILE = False
|
|
26 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
27 |
PREVIEW_IMAGES = True
|
28 |
|
29 |
-
dtype = torch.
|
30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31 |
if torch.cuda.is_available():
|
32 |
-
prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=
|
33 |
-
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=
|
34 |
|
35 |
if ENABLE_CPU_OFFLOAD:
|
36 |
prior_pipeline.enable_model_cpu_offload()
|
@@ -46,6 +46,7 @@ if torch.cuda.is_available():
|
|
46 |
if PREVIEW_IMAGES:
|
47 |
previewer = Previewer()
|
48 |
previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
|
|
|
49 |
def callback_prior(i, t, latents):
|
50 |
output = previewer(latents)
|
51 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
|
@@ -81,9 +82,9 @@ def generate(
|
|
81 |
num_images_per_prompt: int = 2,
|
82 |
#profile: gr.OAuthProfile | None = None,
|
83 |
) -> PIL.Image.Image:
|
84 |
-
prior_pipeline.to(
|
85 |
-
decoder_pipeline.to(
|
86 |
-
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
87 |
generator = torch.Generator().manual_seed(seed)
|
88 |
prior_output = prior_pipeline(
|
89 |
prompt=prompt,
|
|
|
26 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
|
27 |
PREVIEW_IMAGES = True
|
28 |
|
29 |
+
dtype = torch.bfloat16
|
30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31 |
if torch.cuda.is_available():
|
32 |
+
prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=dtype).to(device)
|
33 |
+
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=dtype).to(device)
|
34 |
|
35 |
if ENABLE_CPU_OFFLOAD:
|
36 |
prior_pipeline.enable_model_cpu_offload()
|
|
|
46 |
if PREVIEW_IMAGES:
|
47 |
previewer = Previewer()
|
48 |
previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
|
49 |
+
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
50 |
def callback_prior(i, t, latents):
|
51 |
output = previewer(latents)
|
52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
|
|
|
82 |
num_images_per_prompt: int = 2,
|
83 |
#profile: gr.OAuthProfile | None = None,
|
84 |
) -> PIL.Image.Image:
|
85 |
+
#prior_pipeline.to(device)
|
86 |
+
#decoder_pipeline.to(device)
|
87 |
+
#previewer.eval().requires_grad_(False).to(device).to(dtype)
|
88 |
generator = torch.Generator().manual_seed(seed)
|
89 |
prior_output = prior_pipeline(
|
90 |
prompt=prompt,
|