Gbssreejith commited on
Commit
8ebc1a7
1 Parent(s): 5d0987f

Upload load_weights.py

Browse files
Files changed (1) hide show
  1. load_weights.py +44 -0
load_weights.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def load_weight(model, state_dict):
3
+ old_keys = []
4
+ new_keys = []
5
+ for key in state_dict.keys():
6
+ new_key = None
7
+ if key.endswith(".g"):
8
+ new_key = key[:-2] + ".weight"
9
+ elif key.endswith(".b"):
10
+ new_key = key[:-2] + ".bias"
11
+ elif key.endswith(".w"):
12
+ new_key = key[:-2] + ".weight"
13
+ if new_key:
14
+ old_keys.append(key)
15
+ new_keys.append(new_key)
16
+ for old_key, new_key in zip(old_keys, new_keys):
17
+ state_dict[new_key] = state_dict.pop(old_key)
18
+
19
+ missing_keys = []
20
+ unexpected_keys = []
21
+ error_msgs = []
22
+ # copy state_dict so _load_from_state_dict can modify it
23
+ metadata = getattr(state_dict, "_metadata", None)
24
+ state_dict = state_dict.copy()
25
+ if metadata is not None:
26
+ state_dict._metadata = metadata
27
+
28
+ def load(module, prefix=""):
29
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
30
+ module._load_from_state_dict(
31
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
32
+ )
33
+ for name, child in module._modules.items():
34
+ if child is not None:
35
+ load(child, prefix + name + ".")
36
+
37
+ start_model = model
38
+ if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
39
+ start_model = model.transformer
40
+ load(start_model, prefix="")
41
+
42
+ # Make sure we are still sharing the output and input embeddings after loading weights
43
+ model.set_tied()
44
+ return model