cwkuo committed
Commit febd802 · 1 Parent(s): 7962ed0

trim checkpoint model weights

Files changed (2):
  1. app.py +2 -6
  2. model/utils.py +18 -0
app.py CHANGED
@@ -370,13 +370,9 @@ def build_model():
     _, image_trans = get_gptk_image_transform()
     topk = {"whole": 60, "five": 24, "nine": 16}
     gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
-    gptk_ckpt = "model/ckpt/mp_rank_00_model_states.pt"
+    gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
    gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
-    gptk_ckpt = {
-        ".".join(k.split(".")[2:]): v
-        for k, v in gptk_ckpt["module"].items()
-    }
-    gptk_model.load_state_dict(gptk_ckpt)
+    gptk_model.load_state_dict(gptk_ckpt, strict=False)
    gptk_model = gptk_model.to(device).eval()

    return knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device
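Note on the app.py change: load_state_dict is now called with strict=False because the trimmed checkpoint keeps only a subset of the model's parameters, so the remaining (untouched) weights are expected to be absent. A minimal sketch of how the load result could be inspected, assuming the same gptk_model and checkpoint path as above:

    import torch

    gptk_ckpt = torch.load("model/ckpt/gptk-vicuna7b.pt", map_location="cpu")
    result = gptk_model.load_state_dict(gptk_ckpt, strict=False)
    # With strict=False, load_state_dict returns the non-matching keys:
    # missing_keys are parameters not present in the trimmed checkpoint,
    # unexpected_keys should be empty if the trimming kept only valid names.
    print(len(result.missing_keys), "missing;", len(result.unexpected_keys), "unexpected")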
model/utils.py CHANGED
@@ -1,9 +1,12 @@
 import os
+from pprint import pprint
+from tqdm import tqdm
 import torch
 import torch.distributed as dist
 import timm.models.hub as timm_hub
 
 
+
 def drop_sequence_mask(N, S, device, p=0.1, training=True):
     if training:
         mask = torch.rand((N, S), device=device)
@@ -77,3 +80,18 @@ def download_cached_file(url, check_hash=True, progress=False):
         dist.barrier()
 
     return get_cached_file_path()
+
+
+def trim_ckpt(ckpt_input, ckpt_output, extra_keys=()):
+    kept_keys = ('llm_proj', 'knwl', 'qformer', 'ln_vision', 'query_tokens') + extra_keys
+
+    ckpt = torch.load(ckpt_input, map_location="cpu")
+    ckpt = {
+        ".".join(n.split(".")[2:]): v
+        for n, v in tqdm(ckpt["module"].items(), dynamic_ncols=True)
+        if any([k in n for k in kept_keys])
+    }
+    print("Kept params:")
+    pprint(list(ckpt.keys()))
+
+    torch.save(ckpt, ckpt_output)
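The new trim_ckpt helper strips the wrapper prefix from each parameter name (the first two dot-separated components of the keys under ckpt["module"]) and keeps only entries matching kept_keys, producing the smaller checkpoint that app.py now loads. A usage sketch, assuming the full checkpoint sits at the path the previous app.py revision loaded directly:

    from model.utils import trim_ckpt

    # Produce the trimmed checkpoint consumed by app.py; the input path is an
    # assumption based on the path removed from app.py in this commit.
    trim_ckpt(
        "model/ckpt/mp_rank_00_model_states.pt",  # full checkpoint (assumed location)
        "model/ckpt/gptk-vicuna7b.pt",            # trimmed output loaded by app.py
    )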