PommesPeter commited on
Commit
3c34a7f
1 Parent(s): f18c66a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -241,21 +241,11 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
241
  )
242
  # end sampler
243
 
 
244
  resolution = resolution.split(" ")[-1]
245
  w, h = resolution.split("x")
246
  w, h = int(w), int(h)
247
 
248
- res_cat = (w * h) ** 0.5
249
- seq_len = res_cat // 16
250
-
251
- scaling_method = "ntk"
252
- train_seq_len = 64
253
- if scaling_method == "ntk":
254
- scale_factor = seq_len / train_seq_len
255
- else:
256
- raise NotImplementedError
257
-
258
- print(f"> scale factor: {scale_factor}")
259
 
260
  latent_w, latent_h = w // 8, h // 8
261
  if int(seed) != 0:
@@ -284,9 +274,18 @@ def infer_ode(args, infer_args, text_encoder, tokenizer, vae, model):
284
  cap_feats=cap_feats,
285
  cap_mask=cap_mask,
286
  cfg_scale=cfg_scale,
287
- scale_factor=scale_factor,
288
  )
289
 
 
 
 
 
 
 
 
 
 
 
290
  print("> start sample")
291
  samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
292
  samples = samples[:1]
@@ -511,9 +510,9 @@ def main():
511
  )
512
  with gr.Row():
513
  scale_methods = gr.Dropdown(
514
- value="ntk",
515
- choices=["ntk"],
516
- label="Scale methods",
517
  )
518
  proportional_attn = gr.Checkbox(
519
  value=True,
 
241
  )
242
  # end sampler
243
 
244
+ do_extrapolation = "Extrapolation" in resolution
245
  resolution = resolution.split(" ")[-1]
246
  w, h = resolution.split("x")
247
  w, h = int(w), int(h)
248
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  latent_w, latent_h = w // 8, h // 8
251
  if int(seed) != 0:
 
274
  cap_feats=cap_feats,
275
  cap_mask=cap_mask,
276
  cfg_scale=cfg_scale,
 
277
  )
278
 
279
+ if proportional_attn:
280
+ model_kwargs["proportional_attn"] = True
281
+ model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
282
+ if do_extrapolation and scaling_method == "Time-aware":
283
+ model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size ** 2)
284
+ else:
285
+ model_kwargs["scale_factor"] = 1.0
286
+
287
+ print(f"> scale factor: {model_kwargs["scale_factor"]}")
288
+
289
  print("> start sample")
290
  samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
291
  samples = samples[:1]
 
510
  )
511
  with gr.Row():
512
  scale_methods = gr.Dropdown(
513
+ value="Time-aware",
514
+ choices=["Time-aware", "None"],
515
+ label="Rope scaling method",
516
  )
517
  proportional_attn = gr.Checkbox(
518
  value=True,