cwkuo committed
Commit · febd802
1 Parent(s): 7962ed0

trim checkpoint model weights
Files changed:
- app.py +2 -6
- model/utils.py +18 -0
app.py CHANGED
@@ -370,13 +370,9 @@ def build_model():
     _, image_trans = get_gptk_image_transform()
     topk = {"whole": 60, "five": 24, "nine": 16}
     gptk_model = get_gptk_model(d_knwl=d_knwl, topk=topk)
-    gptk_ckpt = "model/ckpt/
+    gptk_ckpt = "model/ckpt/gptk-vicuna7b.pt"
     gptk_ckpt = torch.load(gptk_ckpt, map_location="cpu")
-    gptk_ckpt = {
-        ".".join(k.split(".")[2:]): v
-        for k, v in gptk_ckpt["module"].items()
-    }
-    gptk_model.load_state_dict(gptk_ckpt)
+    gptk_model.load_state_dict(gptk_ckpt, strict=False)
     gptk_model = gptk_model.to(device).eval()
 
     return knwl_db, query_enc, query_trans, gptk_model, image_trans, topk, device
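Since the checkpoint is now trimmed to a subset of the model's parameters, `load_state_dict` is called with `strict=False` so the remaining (pretrained) weights keep their initialized values. A minimal sanity-check sketch, not part of the commit, using the names from the diff above:

# Hypothetical check: load_state_dict returns the names it could not match.
# With a trimmed checkpoint, "unexpected_keys" should be empty and
# "missing_keys" should only list weights that stay frozen/pretrained.
result = gptk_model.load_state_dict(gptk_ckpt, strict=False)
print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")
assert not result.unexpected_keys, "trimmed checkpoint contains unknown parameters"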
model/utils.py CHANGED
@@ -1,9 +1,12 @@
 import os
+from pprint import pprint
+from tqdm import tqdm
 import torch
 import torch.distributed as dist
 import timm.models.hub as timm_hub
 
 
+
 def drop_sequence_mask(N, S, device, p=0.1, training=True):
     if training:
         mask = torch.rand((N, S), device=device)
@@ -77,3 +80,18 @@ def download_cached_file(url, check_hash=True, progress=False):
     dist.barrier()
 
     return get_cached_file_path()
+
+
+def trim_ckpt(ckpt_input, ckpt_output, extra_keys=()):
+    kept_keys = ('llm_proj', 'knwl', 'qformer', 'ln_vision', 'query_tokens') + extra_keys
+
+    ckpt = torch.load(ckpt_input, map_location="cpu")
+    ckpt = {
+        ".".join(n.split(".")[2:]): v
+        for n, v in tqdm(ckpt["module"].items(), dynamic_ncols=True)
+        if any([k in n for k in kept_keys])
+    }
+    print("Kept params:")
+    pprint(list(ckpt.keys()))
+
+    torch.save(ckpt, ckpt_output)
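The trimmed file loaded in app.py would be produced once, offline, with this helper. A minimal usage sketch, where the input path is an illustrative assumption rather than a value from the commit:

from model.utils import trim_ckpt

# Read the full training checkpoint (expected to carry a top-level "module"
# dict), keep only parameter names containing one of the kept_keys substrings,
# drop the first two components of each name, and write the much smaller
# state dict that app.py now loads directly.
trim_ckpt(
    ckpt_input="model/ckpt/full_training_ckpt.pt",  # hypothetical source checkpoint
    ckpt_output="model/ckpt/gptk-vicuna7b.pt",      # file loaded in app.py
)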