# This code takes the pytorch weights generated using the paddle2torch_weights
# script and stacks the Query, Key and Value projections of every encoder
# self-attention (self_attn) layer, to match the layout expected by
# torch.nn.MultiheadAttention.
import os

import torch

full_state_dict = torch.load("./pytorch_model.bin", map_location="cpu")
# Drop the leading module prefix from every key, so e.g.
# "<prefix>.layers.0.self_attn.q_proj.weight" becomes
# "layers.0.self_attn.q_proj.weight".
full_state_dict = dict((".".join(k.split(".")[1:]), v)
                       for k, v in full_state_dict.items())


def con_cat(kqv_dict):
    """Concatenate one layer's separate q/k/v projection tensors into the
    single in_proj_weight / in_proj_bias tensor used by
    torch.nn.MultiheadAttention."""
    key0 = next(iter(kqv_dict))
    proj = key0.split(".")[3]  # "q_proj", "k_proj" or "v_proj"
    suffix = "weight" if "weight" in key0 else "bias"
    # MultiheadAttention stacks the projections in (q, k, v) order.
    c_dict_value = torch.cat([kqv_dict[key0.replace(proj, p)]
                              for p in ("q_proj", "k_proj", "v_proj")])
    c_dict_key = ".".join(key0.split(".")[:3] + [f"in_proj_{suffix}"])
    return {f"encoder.{c_dict_key}": c_dict_value}


mod_dict = {}

# Embedding weights
for k, v in full_state_dict.items():
    if "embedding" in k or "layer_norm" in k:
        mod_dict.update({f"embeddings.{k}": v})

# Encoder weights
for i in range(24):
    # The trailing dot keeps e.g. "layers.1." from also matching "layers.10".
    sd = dict((k, v) for k, v in full_state_dict.items() if f"layers.{i}." in k)
    kvq_weight = {}
    kvq_bias = {}
    for k, v in sd.items():
        if "self_attn" in k and "out_proj" not in k:
            if "weight" in k:
                kvq_weight[k] = v
            if "bias" in k:
                kvq_bias[k] = v
        else:
            mod_dict[f"encoder.{k}"] = v
    mod_dict.update(con_cat(kvq_weight))
    mod_dict.update(con_cat(kvq_bias))

# Pooler
for k, v in full_state_dict.items():
    if "dense" in k:
        mod_dict.update({f"pooler.{k}": v})

for k, v in mod_dict.items():
    print(k, v.size())

model_name = "ernie-m-base_pytorch"
os.makedirs(model_name, exist_ok=True)  # make sure the output directory exists
PATH = f"./{model_name}/pytorch_model.bin"
torch.save(mod_dict, PATH)
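
# Optional sanity check -- a minimal sketch, not part of the conversion itself.
# Projecting a random input with the stacked in_proj tensors and splitting the
# result into three chunks should reproduce the three separate q/k/v
# projections. Layer 0 is used as an example; the key names follow the layout
# the conversion above already relies on.
import torch.nn.functional as F

w_q = full_state_dict["layers.0.self_attn.q_proj.weight"]
b_q = full_state_dict["layers.0.self_attn.q_proj.bias"]
w_in = mod_dict["encoder.layers.0.self_attn.in_proj_weight"]
b_in = mod_dict["encoder.layers.0.self_attn.in_proj_bias"]

x = torch.randn(2, w_q.size(1))  # dummy batch of hidden states
q, k, v = F.linear(x, w_in, b_in).chunk(3, dim=-1)
assert torch.allclose(q, F.linear(x, w_q, b_q), atol=1e-6)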