ernie-m-base_pytorch / pytorch_weights_postprocess.py
# This script takes the PyTorch weights generated by the paddle2torch_weights script and stacks
# the query, key and value projections of each encoder self-attention layer (self_attn) into
# single fused tensors, matching the parameter layout of torch.nn.MultiheadAttention.
import os

import torch
# Load the converted checkpoint and strip the leading module prefix from every key.
full_state_dict = torch.load("./pytorch_model.bin")
full_state_dict = dict((".".join(k.split(".")[1:]), v)
                       for k, v in full_state_dict.items())
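# Quick inspection (not part of the original script, purely illustrative): print a few of the
# stripped keys to confirm they now start directly with "layers.", "embedding", "pooler", etc.
for _k in list(full_state_dict)[:5]:
    print(_k)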

def con_cat(kqv_dict):
    """Fuse the separate q_proj/k_proj/v_proj tensors of one layer into a single in_proj tensor."""
    kqv_dict_keys = list(kqv_dict.keys())
    if "weight" in kqv_dict_keys[0]:
        # Keys look like "layers.{i}.self_attn.{q|k|v}_proj.weight"; grab the projection name.
        tmp = kqv_dict_keys[0].split(".")[3]
        # Concatenate in query, key, value order, as torch.nn.MultiheadAttention expects.
        c_dict_value = torch.cat([kqv_dict[kqv_dict_keys[0].replace(tmp, "q_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "k_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "v_proj")]])
        c_dict_key = ".".join(kqv_dict_keys[0].split(".")[:3] + ["in_proj_weight"])
        return {f"encoder.{c_dict_key}": c_dict_value}
    if "bias" in kqv_dict_keys[0]:
        tmp = kqv_dict_keys[0].split(".")[3]
        c_dict_value = torch.cat([kqv_dict[kqv_dict_keys[0].replace(tmp, "q_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "k_proj")],
                                  kqv_dict[kqv_dict_keys[0].replace(tmp, "v_proj")]])
        c_dict_key = ".".join(kqv_dict_keys[0].split(".")[:3] + ["in_proj_bias"])
        return {f"encoder.{c_dict_key}": c_dict_value}

mod_dict = {}
# Embedding weights
for k, v in full_state_dict.items():
    if "embedding" in k or "layer_norm" in k:
        mod_dict.update({f"embeddings.{k}": v})
# Encoder weights: fuse the q/k/v projections of each layer, copy everything else unchanged.
for i in range(12):
    # Match "layers.{i}." exactly so that e.g. i == 1 does not also pick up layers 10 and 11.
    sd = dict((k, v) for k, v in full_state_dict.items() if f"layers.{i}." in k)
    kvq_weight = {}
    kvq_bias = {}
    for k, v in sd.items():
        if "self_attn" in k and "out_proj" not in k:
            if "weight" in k:
                kvq_weight[k] = v
            if "bias" in k:
                kvq_bias[k] = v
        else:
            mod_dict[f"encoder.{k}"] = v
    mod_dict.update(con_cat(kvq_weight))
    mod_dict.update(con_cat(kvq_bias))
# Pooler
for k, v in full_state_dict.items():
    if "pooler" in k:
        mod_dict.update({k: v})
# Print the final key/shape layout for a quick visual check before saving.
for k, v in mod_dict.items():
    print(k, v.size())
model_name = "ernie-m-base_pytorch"
PATH = f"./{model_name}/pytorch_model.bin"
os.makedirs(model_name, exist_ok=True)  # make sure the output directory exists before saving
torch.save(mod_dict, PATH)
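
# Optional sanity check (a sketch, not part of the original conversion): reload the saved
# state dict and confirm every encoder layer exposes fused in_proj_weight / in_proj_bias
# tensors with the shapes torch.nn.MultiheadAttention expects.
reloaded = torch.load(PATH)
for i in range(12):
    w = reloaded[f"encoder.layers.{i}.self_attn.in_proj_weight"]
    b = reloaded[f"encoder.layers.{i}.self_attn.in_proj_bias"]
    assert w.shape[0] == 3 * w.shape[1]  # in_proj_weight is (3 * embed_dim, embed_dim)
    assert b.shape[0] == w.shape[0]      # in_proj_bias is (3 * embed_dim,)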