Lev McKinney commited on
Commit
b9dc122
·
1 Parent(s): dba1d6e

added more logging to migration process

Browse files
Files changed (1) hide show
  1. lens_migration.py +6 -3
lens_migration.py CHANGED
@@ -5,6 +5,7 @@ from copy import deepcopy
5
  import inspect
6
  from logging import warn
7
  from pathlib import Path
 
8
  import json
9
 
10
  from tuned_lens.model_surgery import get_final_norm, get_transformer_layers
@@ -352,13 +353,15 @@ if __name__ == "__main__":
352
 
353
  device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
354
 
 
355
  tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)
356
 
 
357
  tuned_lens = TunedLens.from_model(
358
  model, bias=tuned_lens_old.config['bias'], revision=revision
359
  )
360
 
361
- for i in range(len(tuned_lens_old)):
362
  tuned_lens[i].load_state_dict(tuned_lens_old[i].state_dict())
363
 
364
 
@@ -368,7 +371,7 @@ if __name__ == "__main__":
368
 
369
  # Fuzz the new lens against the old one's
370
  with th.no_grad():
371
- for i in range(len(tuned_lens)):
372
  for _ in range(10):
373
  a = th.randn(1, 1, tuned_lens.config.d_model, device=device)
374
  logits_new = tuned_lens(a, i)
@@ -377,5 +380,5 @@ if __name__ == "__main__":
377
  log_ps_old = logits_old.log_softmax(-1)
378
  assert (th.allclose(log_ps_new, log_ps_old))
379
  print("js div", js_divergence(log_ps_new, log_ps_old))
380
-
381
  tuned_lens.to(th.device("cpu")).save(args.output_dir)
 
5
  import inspect
6
  from logging import warn
7
  from pathlib import Path
8
+ import tqdm
9
  import json
10
 
11
  from tuned_lens.model_surgery import get_final_norm, get_transformer_layers
 
353
 
354
  device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
355
 
356
+ print("Loading old lens")
357
  tuned_lens_old = TunedLensOld.load(args.resource_id, map_location=device)
358
 
359
+ print("Initializing new lens")
360
  tuned_lens = TunedLens.from_model(
361
  model, bias=tuned_lens_old.config['bias'], revision=revision
362
  )
363
 
364
+ for i in tqdm(range(len(tuned_lens_old)), desc="Copying parameters"):
365
  tuned_lens[i].load_state_dict(tuned_lens_old[i].state_dict())
366
 
367
 
 
371
 
372
  # Fuzz the new lens against the old one's
373
  with th.no_grad():
374
+ for i in tqdm(range(len(tuned_lens)), desc="Fuzzing layers"):
375
  for _ in range(10):
376
  a = th.randn(1, 1, tuned_lens.config.d_model, device=device)
377
  logits_new = tuned_lens(a, i)
 
380
  log_ps_old = logits_old.log_softmax(-1)
381
  assert (th.allclose(log_ps_new, log_ps_old))
382
  print("js div", js_divergence(log_ps_new, log_ps_old))
383
+ print("Saving new lens to", args.output_dir)
384
  tuned_lens.to(th.device("cpu")).save(args.output_dir)