PommesPeter committed on
Commit eab6e9f · 1 Parent(s): c13cbf5

Update app.py

Files changed (1)
  1. app.py +30 -27
app.py CHANGED
@@ -194,7 +194,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         solver,
         t_shift,
         seed,
-        ntk_scaling,
+        scale_method,
         proportional_attn,
     ) = infer_args
 
@@ -207,7 +207,7 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
         solver=solver,
         t_shift=t_shift,
         seed=seed,
-        ntk_scaling=ntk_scaling,
+        ntk_scaling=scale_method,
         proportional_attn=proportional_attn,
     )
     print("> params:", json.dumps(metadata, indent=2))
@@ -244,6 +244,19 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     resolution = resolution.split(" ")[-1]
     w, h = resolution.split("x")
     w, h = int(w), int(h)
+
+    res_cat = (w * h) ** 0.5
+    seq_len = res_cat // 16
+
+    scaling_method = "ntk"
+    train_seq_len = 64
+    if scaling_method == "ntk":
+        scale_factor = seq_len / train_seq_len
+    else:
+        raise NotImplementedError
+
+    print(f"> scale factor: {scale_factor}")
+
     latent_w, latent_h = w // 8, h // 8
     if int(seed) != 0:
         torch.random.manual_seed(int(seed))
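The added block computes a single NTK-style scale factor from the target resolution: res_cat is the geometric mean of width and height in pixels, the division by 16 turns it into a token-grid side length, and that side length is compared with the training-time value of 64. A standalone sketch of the same arithmetic, assuming the train_seq_len = 64 default from the diff; the helper name scale_factor_for and the example resolutions are illustrative only, not part of app.py:

def scale_factor_for(w: int, h: int, train_seq_len: int = 64) -> float:
    # Same arithmetic as the hunk above: geometric-mean resolution in pixels,
    # divided by 16 to get a token-grid side length, then compared against the
    # training-time side length.
    res_cat = (w * h) ** 0.5
    seq_len = res_cat // 16
    return seq_len / train_seq_len

if __name__ == "__main__":
    for w, h in [(1024, 1024), (1536, 1536), (2048, 1024)]:
        print(f"{w}x{h} -> scale_factor = {scale_factor_for(w, h)}")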
@@ -251,37 +264,27 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
     z = z.repeat(2, 1, 1, 1)
 
     with torch.no_grad():
-        cap_feats, cap_mask = encode_prompt(
-            [cap] + [""], text_encoder, tokenizer, 0.0
-        )
         if neg_cap != "":
-            neg_cap_feats, neg_cap_mask = encode_prompt(
-                [neg_cap] + [""],
+            cap_feats, cap_mask = encode_prompt(
+                [cap] + [neg_cap],
+                text_encoder,
+                tokenizer,
+                0.0,
+            )
+        else:
+            cap_feats, cap_mask = encode_prompt(
+                [cap] + [""],
                 text_encoder,
                 tokenizer,
                 0.0,
             )
-            cap_feats = torch.cat([neg_cap_feats, cap_feats], dim=1)
-            cap_mask = torch.cat([neg_cap_mask, cap_mask], dim=1)
-
         cap_mask = cap_mask.to(cap_feats.device)
 
-        train_res = 1024
-        res_cat = (w * h) ** 0.5
-        print(f"res_cat: {res_cat}")
-        max_seq_len = (res_cat // 16) ** 2 + (res_cat // 16) * 2
-        print(f"max_seq_len: {max_seq_len}")
-
-        rope_scaling_factor = 1.0
-        ntk_factor = max_seq_len / (train_res // 16) ** 2
-        print(f"ntk_factor: {ntk_factor}")
-
         model_kwargs = dict(
             cap_feats=cap_feats,
             cap_mask=cap_mask,
             cfg_scale=cfg_scale,
-            rope_scaling_factor=rope_scaling_factor,
-            ntk_factor=ntk_factor,
+            scale_factor=scale_factor,
        )
 
        print("> start sample")
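The restructured encoding builds one batch of two captions, positive first and negative (or empty) second, which lines up with z = z.repeat(2, 1, 1, 1) so both guidance branches share a single forward pass. A generic sketch of that classifier-free-guidance pattern, assuming a placeholder denoise callable rather than the actual model call in app.py:

import torch

def cfg_combine(denoise, z, cap_feats, cap_mask, cfg_scale: float):
    # denoise is a hypothetical stand-in for the model call; it takes a batch of
    # latents plus the paired caption features and returns one prediction per item.
    z_pair = z.repeat(2, 1, 1, 1)  # conditional branch + negative/unconditional branch
    pred_cond, pred_neg = denoise(z_pair, cap_feats, cap_mask).chunk(2, dim=0)
    # Standard classifier-free guidance: push away from the negative prediction.
    return pred_neg + cfg_scale * (pred_cond - pred_neg)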
@@ -504,10 +507,10 @@ def main():
                    label="CFG scale",
                )
                with gr.Row():
-                    ntk_scaling = gr.Checkbox(
-                        value=True,
-                        interactive=True,
-                        label="ntk scaling",
+                    scale_methods = gr.Dropdown(
+                        value="ntk",
+                        choices=["ntk"],
+                        label="Scale methods",
                    )
                    proportional_attn = gr.Checkbox(
                        value=True,
@@ -608,7 +611,7 @@ def main():
                solver,
                t_shift,
                seed,
-                ntk_scaling,
+                scale_methods,
                proportional_attn,
            ],
            [output_img, gr_metadata],
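On the UI side, the "ntk scaling" checkbox becomes a "Scale methods" dropdown with "ntk" as its only choice for now, leaving room for additional scaling methods later. A minimal, self-contained Gradio sketch of the same widget wiring, with echo_choice as a stand-in callback rather than the real infer_ode handler:

import gradio as gr

def echo_choice(method: str) -> str:
    # Stand-in callback; the real app passes the selection into infer_ode.
    return f"selected scale method: {method}"

with gr.Blocks() as demo:
    scale_methods = gr.Dropdown(value="ntk", choices=["ntk"], label="Scale methods")
    result = gr.Textbox(label="result")
    scale_methods.change(echo_choice, [scale_methods], [result])

if __name__ == "__main__":
    demo.launch()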
 