Spaces:
Sleeping
Sleeping
Gbssreejith
committed on
Commit
•
8ebc1a7
1
Parent(s):
5d0987f
Upload load_weights.py
Browse files- load_weights.py +44 -0
load_weights.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def load_weight(model, state_dict):
    """Load ``state_dict`` into ``model`` after normalising legacy key names.

    Keys ending in ``.g`` or ``.w`` are renamed to end in ``.weight`` and keys
    ending in ``.b`` are renamed to end in ``.bias``. The renaming is applied
    to a shallow copy, so the mapping passed in by the caller is not modified.

    Args:
        model: Object exposing ``_modules``, ``_load_from_state_dict`` and
            ``set_tied``. If it has a ``transformer`` attribute and no key in
            the (renamed) state dict starts with ``transformer.``, loading
            starts from ``model.transformer`` instead of ``model``.
        state_dict: Mapping of parameter names to values. It must provide
            ``copy()``; an optional ``_metadata`` attribute is carried over
            to the copy and consulted per sub-module prefix.

    Returns:
        The ``model`` object that was passed in.
    """
    import warnings

    suffix_map = {".g": ".weight", ".b": ".bias", ".w": ".weight"}

    # Copy first: the renaming below and _load_from_state_dict may both modify
    # the mapping, and the caller's object must stay as it was.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Collect the renames before applying them so the mapping is not resized
    # while it is being iterated.
    renames = [
        (key, key[:-2] + suffix_map[key[-2:]])
        for key in state_dict
        if key[-2:] in suffix_map
    ]
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=""):
        # Metadata is keyed by the module path without the trailing dot.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    start_model = model
    if hasattr(model, "transformer") and all(
        not key.startswith("transformer.") for key in state_dict
    ):
        start_model = model.transformer
    load(start_model, prefix="")

    # Report collected load errors instead of dropping them silently; a
    # warning keeps the function best-effort rather than raising.
    if error_msgs:
        warnings.warn(
            "Error(s) while loading state_dict for {}:\n\t{}".format(
                type(model).__name__, "\n\t".join(str(msg) for msg in error_msgs)
            ),
            RuntimeWarning,
            stacklevel=2,
        )

    # Make sure we are still sharing the output and input embeddings after loading weights
    model.set_tied()
    return model
|