Spaces:
Paused
Paused
Create model_loading.py
Browse files- utils/model_loading.py +42 -0
utils/model_loading.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
3 |
+
# Copyright (c) 2022 Microsoft
|
4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
5 |
+
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from utils.distributed import is_main_process
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
|
14 |
+
model_keys = sorted(model_state_dict.keys())
|
15 |
+
ckpt_keys = sorted(ckpt_state_dict.keys())
|
16 |
+
result_dicts = {}
|
17 |
+
matched_log = []
|
18 |
+
unmatched_log = []
|
19 |
+
unloaded_log = []
|
20 |
+
for model_key in model_keys:
|
21 |
+
model_weight = model_state_dict[model_key]
|
22 |
+
if model_key in ckpt_keys:
|
23 |
+
ckpt_weight = ckpt_state_dict[model_key]
|
24 |
+
if model_weight.shape == ckpt_weight.shape:
|
25 |
+
result_dicts[model_key] = ckpt_weight
|
26 |
+
ckpt_keys.pop(ckpt_keys.index(model_key))
|
27 |
+
matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
|
28 |
+
else:
|
29 |
+
unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
|
30 |
+
else:
|
31 |
+
unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
|
32 |
+
|
33 |
+
if is_main_process():
|
34 |
+
for info in matched_log:
|
35 |
+
logger.info(info)
|
36 |
+
for info in unloaded_log:
|
37 |
+
logger.warning(info)
|
38 |
+
for key in ckpt_keys:
|
39 |
+
logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
|
40 |
+
for info in unmatched_log:
|
41 |
+
logger.warning(info)
|
42 |
+
return result_dicts
|