Spaces:
Sleeping
Sleeping
def load_weight(model, state_dict):
    """Load a (possibly legacy-named) state dict into ``model`` in place.

    Keys ending in ``.g`` or ``.w`` are renamed to end in ``.weight``, and
    keys ending in ``.b`` to end in ``.bias``, before loading. Presumably this
    is the naming used by an original/converted checkpoint — TODO confirm.
    The renaming is done on a copy, so the caller's ``state_dict`` is left
    unchanged.

    If ``model`` has a ``transformer`` attribute and no key starts with
    ``"transformer."``, the keys are treated as relative to
    ``model.transformer`` and loading starts there. Otherwise loading starts
    at ``model`` itself.

    Args:
        model: module to load into. It must provide ``set_tied()`` and
            PyTorch's ``_load_from_state_dict`` / ``_modules`` internals.
        state_dict: mapping of parameter/buffer names to tensors. It may
            carry a ``_metadata`` attribute (per-module version info).

    Returns:
        The same ``model`` object, after loading and calling ``set_tied()``.

    Warns:
        UserWarning: if keys were missing from, or unexpected in, the
        checkpoint, or if a tensor could not be copied (e.g. a size
        mismatch). Loading is best-effort and does not raise for these.
    """
    import warnings

    # Legacy two-character suffix -> modern parameter-name suffix.
    suffix_map = {".g": ".weight", ".b": ".bias", ".w": ".weight"}

    # Copy first (keeping the metadata attribute, which dict.copy() drops)
    # so the renaming below and _load_from_state_dict never touch the
    # caller's mapping.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # Iterate over a snapshot of the keys because the loop mutates the dict.
    for key in list(state_dict):
        new_suffix = suffix_map.get(key[-2:])
        if new_suffix is not None:
            state_dict[key[:-2] + new_suffix] = state_dict.pop(key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    def load(module, prefix=""):
        # Per-module metadata is keyed by the prefix without its trailing dot.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Checkpoints saved from the bare transformer have no "transformer."
    # prefix; in that case start the traversal at model.transformer.
    start_model = model
    if hasattr(model, "transformer") and not any(
        key.startswith("transformer.") for key in state_dict
    ):
        start_model = model.transformer
    load(start_model, prefix="")

    # Surface load problems instead of silently discarding them.
    problems = []
    if missing_keys:
        problems.append(f"missing keys: {missing_keys}")
    if unexpected_keys:
        problems.append(f"unexpected keys: {unexpected_keys}")
    problems.extend(error_msgs)
    if problems:
        details = "\n\t".join(problems)
        warnings.warn(f"Problems while loading weights:\n\t{details}", stacklevel=2)

    # Make sure we are still sharing the output and input embeddings after loading weights
    model.set_tied()
    return model