Update app.py
app.py CHANGED

@@ -194,7 +194,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         solver,
         t_shift,
         seed,
-
+        scale_method,
         proportional_attn,
     ) = infer_args
 
@@ -207,7 +207,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         solver=solver,
         t_shift=t_shift,
         seed=seed,
-        ntk_scaling=
+        ntk_scaling=scale_method,
         proportional_attn=proportional_attn,
     )
     print("> params:", json.dumps(metadata, indent=2))
@@ -244,6 +244,19 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     resolution = resolution.split(" ")[-1]
     w, h = resolution.split("x")
     w, h = int(w), int(h)
+
+    res_cat = (w * h) ** 0.5
+    seq_len = res_cat // 16
+
+    scaling_method = "ntk"
+    train_seq_len = 64
+    if scaling_method == "ntk":
+        scale_factor = seq_len / train_seq_len
+    else:
+        raise NotImplementedError
+
+    print(f"> scale factor: {scale_factor}")
+
     latent_w, latent_h = w // 8, h // 8
     if int(seed) != 0:
         torch.random.manual_seed(int(seed))
@@ -251,37 +264,27 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     z = z.repeat(2, 1, 1, 1)
 
     with torch.no_grad():
-        cap_feats, cap_mask = encode_prompt(
-            [cap] + [""], text_encoder, tokenizer, 0.0
-        )
         if neg_cap != "":
-            neg_cap_feats, neg_cap_mask = encode_prompt(
-                [neg_cap],
+            cap_feats, cap_mask = encode_prompt(
+                [cap] + [neg_cap],
+                text_encoder,
+                tokenizer,
+                0.0,
+            )
+        else:
+            cap_feats, cap_mask = encode_prompt(
+                [cap] + [""],
                 text_encoder,
                 tokenizer,
                 0.0,
             )
-            cap_feats = torch.cat([neg_cap_feats, cap_feats], dim=1)
-            cap_mask = torch.cat([neg_cap_mask, cap_mask], dim=1)
-
         cap_mask = cap_mask.to(cap_feats.device)
 
-    train_res = 1024
-    res_cat = (w * h) ** 0.5
-    print(f"res_cat: {res_cat}")
-    max_seq_len = (res_cat // 16) ** 2 + (res_cat // 16) * 2
-    print(f"max_seq_len: {max_seq_len}")
-
-    rope_scaling_factor = 1.0
-    ntk_factor = max_seq_len / (train_res // 16) ** 2
-    print(f"ntk_factor: {ntk_factor}")
-
     model_kwargs = dict(
         cap_feats=cap_feats,
         cap_mask=cap_mask,
         cfg_scale=cfg_scale,
-
-        ntk_factor=ntk_factor,
+        scale_factor=scale_factor,
     )
 
     print("> start sample")
@@ -504,10 +507,10 @@ def main():
                     label="CFG scale",
                 )
                 with gr.Row():
-
-                        value=
-
-                        label="
+                    scale_methods = gr.Dropdown(
+                        value="ntk",
+                        choices=["ntk"],
+                        label="Scale methods",
                     )
                     proportional_attn = gr.Checkbox(
                         value=True,
@@ -608,7 +611,7 @@ def main():
                         solver,
                         t_shift,
                         seed,
-
+                        scale_methods,
                        proportional_attn,
                     ],
                     [output_img, gr_metadata],
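
For reference, the scale factor added above reduces to a ratio between the target image's per-axis token count (the geometric-mean resolution divided by a 16-pixel patch size) and the hard-coded train_seq_len = 64, which matches the 1024-pixel train_res divided by 16 in the removed code. Below is a minimal standalone sketch of the same arithmetic; the helper name and the example resolutions are illustrative and not part of the commit.

# Sketch of the "ntk" scale-factor computation introduced in this commit.
# compute_scale_factor and the example resolutions are illustrative only.

def compute_scale_factor(w: int, h: int, train_seq_len: int = 64) -> float:
    res_cat = (w * h) ** 0.5        # geometric mean of width and height
    seq_len = res_cat // 16         # per-axis token count with 16-px patches
    return seq_len / train_seq_len  # ratio to the training-time token count


if __name__ == "__main__":
    for w, h in [(1024, 1024), (1536, 1536), (2048, 2048)]:
        print(f"{w}x{h} -> scale factor {compute_scale_factor(w, h)}")
    # 1024x1024 -> 1.0, 1536x1536 -> 1.5, 2048x2048 -> 2.0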