fffiloni commited on
Commit
1469496
1 Parent(s): c381dd3

Create model_loading.py

Browse files
Files changed (1) hide show
  1. utils/model_loading.py +42 -0
utils/model_loading.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # X-Decoder -- Generalized Decoding for Pixel, Image, and Language
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Xueyan Zou (xueyan@cs.wisc.edu)
6
+ # --------------------------------------------------------
7
+
8
+ import logging
9
+ from utils.distributed import is_main_process
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def align_and_update_state_dicts(model_state_dict, ckpt_state_dict):
14
+ model_keys = sorted(model_state_dict.keys())
15
+ ckpt_keys = sorted(ckpt_state_dict.keys())
16
+ result_dicts = {}
17
+ matched_log = []
18
+ unmatched_log = []
19
+ unloaded_log = []
20
+ for model_key in model_keys:
21
+ model_weight = model_state_dict[model_key]
22
+ if model_key in ckpt_keys:
23
+ ckpt_weight = ckpt_state_dict[model_key]
24
+ if model_weight.shape == ckpt_weight.shape:
25
+ result_dicts[model_key] = ckpt_weight
26
+ ckpt_keys.pop(ckpt_keys.index(model_key))
27
+ matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
28
+ else:
29
+ unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape))
30
+ else:
31
+ unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape))
32
+
33
+ if is_main_process():
34
+ for info in matched_log:
35
+ logger.info(info)
36
+ for info in unloaded_log:
37
+ logger.warning(info)
38
+ for key in ckpt_keys:
39
+ logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape))
40
+ for info in unmatched_log:
41
+ logger.warning(info)
42
+ return result_dicts