levmckinney commited on
Commit
71df1af
1 Parent(s): 62c1cda

reduced atol to migrate pythia 1.4b deduped v0

Browse files
Files changed (1) hide show
  1. lens_migration.py +1 -1
lens_migration.py CHANGED
@@ -379,6 +379,6 @@ if __name__ == "__main__":
379
  log_ps_new = logits_new.log_softmax(-1)
380
  log_ps_old = logits_old.log_softmax(-1)
381
  print("js div", js_divergence(log_ps_new, log_ps_old))
382
- assert (th.allclose(log_ps_new, log_ps_old, atol=1e-7)), (log_ps_new - log_ps_old).abs().max()
383
  print("Saving new lens to", args.output_dir)
384
  tuned_lens.to(th.device("cpu")).save(args.output_dir)
 
379
  log_ps_new = logits_new.log_softmax(-1)
380
  log_ps_old = logits_old.log_softmax(-1)
381
  print("js div", js_divergence(log_ps_new, log_ps_old))
382
+ assert (th.allclose(log_ps_new, log_ps_old, atol=1e-4)), (log_ps_new - log_ps_old).abs().max()
383
  print("Saving new lens to", args.output_dir)
384
  tuned_lens.to(th.device("cpu")).save(args.output_dir)