Spaces:
Runtime error
Runtime error
Commit
•
7db2e08
1
Parent(s):
95d3f19
Update app.py
Browse files
app.py
CHANGED
@@ -43,15 +43,14 @@ if torch.cuda.is_available():
|
|
43 |
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
|
44 |
|
45 |
if PREVIEW_IMAGES:
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
# previewer.eval().requires_grad_(False).to(device).to(dtype)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
|
56 |
else:
|
57 |
previewer = None
|
@@ -97,12 +96,12 @@ def generate(
|
|
97 |
callback=callback_prior,
|
98 |
)
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
|
107 |
decoder_output = decoder_pipeline(
|
108 |
image_embeddings=prior_output.image_embeddings,
|
|
|
43 |
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
|
44 |
|
45 |
if PREVIEW_IMAGES:
|
46 |
+
previewer = Previewer()
|
47 |
+
previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
|
48 |
+
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
|
|
49 |
|
50 |
+
def callback_prior(i, t, latents):
|
51 |
+
output = previewer(latents)
|
52 |
+
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
|
53 |
+
return output
|
54 |
|
55 |
else:
|
56 |
previewer = None
|
|
|
96 |
callback=callback_prior,
|
97 |
)
|
98 |
|
99 |
+
if PREVIEW_IMAGES:
|
100 |
+
for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
|
101 |
+
r = next(prior_output)
|
102 |
+
if isinstance(r, list):
|
103 |
+
yield r[0]
|
104 |
+
prior_output = r
|
105 |
|
106 |
decoder_output = decoder_pipeline(
|
107 |
image_embeddings=prior_output.image_embeddings,
|