Questions about the model's training details
#2, opened by zrqssjdys
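First of all, many thanks to the author for this work; it is an excellent model. I recently ran into two problems while training and using it, and I would like to ask about them:
1. During training, is only the newly added fully connected layer trained, or is the whole model trained together with the fully connected layer?
2. When saving the model, how do I save it in the same format you shared, with the model and the fully connected layer saved separately? The model I save myself cannot be loaded correctly by sentence-transformers. Could you tell me which package and method you used to save the model?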
1: Whether training the initial model, the later dialogue-embedding model, or the later MRL version, everything is trained together; no parameters are frozen.
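2: If you save the model directly, sentence-transformers cannot load it, because the files are not in the format it requires. First, look at the directory layout, file contents, and weight names of a SentenceTransformer model that includes a fully connected (Dense) layer, and work out how they are structured. Then manually read the fully connected layer's weights out of the weight file you saved directly, write them to a separate pytorch_model.bin under 2_Dense, and add or modify the config files to match the parameter dimensions.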
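To compare a saved directory against that layout, a small file check like the one below may help. It is a minimal sketch: the helper `missing_st_files` is hypothetical, and the file list is only an assumption based on what the snippet in this thread writes, plus the backbone's own `config.json` and `pytorch_model.bin` at the root. Tokenizer files and the `model.safetensors` variants that newer sentence-transformers versions may write are not checked.

```python
import os
from typing import List

# Files expected for a Transformer -> Pooling -> Dense model laid out as in this
# thread (assumption: root weights saved as pytorch_model.bin; tokenizer files
# are not checked).
EXPECTED_FILES = [
    "config.json",                # backbone (Transformer) config
    "pytorch_model.bin",          # backbone weights saved by the training script
    "modules.json",               # module order and sub-directories
    "sentence_bert_config.json",  # settings for the Transformer module
    os.path.join("1_Pooling", "config.json"),
    os.path.join("2_Dense", "config.json"),
    os.path.join("2_Dense", "pytorch_model.bin"),
]


def missing_st_files(ckpt_dir: str) -> List[str]:
    """Return the expected files that do not exist under ckpt_dir."""
    return [rel for rel in EXPECTED_FILES
            if not os.path.isfile(os.path.join(ckpt_dir, rel))]
```

Calling `missing_st_files(ckpt_dir)` and getting an empty list means every listed file is present; it says nothing about whether their contents are correct.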
Relevant code:
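```python
import json
import os

import torch

# ckpt_dir, model_conf (the backbone config, which provides hidden_size) and
# VEC_DIM (output size of the added fully connected layer) are assumed to be
# defined earlier in the training script.

# 1_Pooling: mean pooling over the token embeddings
os.makedirs(os.path.join(ckpt_dir, "1_Pooling"), exist_ok=True)
with open(os.path.join(ckpt_dir, "1_Pooling", "config.json"), "w", encoding="utf8") as fw:
    json.dump({
        "word_embedding_dimension": model_conf.hidden_size,
        "pooling_mode_cls_token": False,
        "pooling_mode_mean_tokens": True,
        "pooling_mode_max_tokens": False,
        "pooling_mode_mean_sqrt_len_tokens": False
    }, fw, ensure_ascii=False, indent=1)
```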
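With `pooling_mode_mean_tokens` set to `True`, the Pooling module averages the token embeddings over the non-padding positions given by the attention mask. A pure-Python sketch of that computation follows; the function name `mean_pool` is hypothetical, and the real module works on batched tensors rather than lists.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the embeddings of tokens whose mask value is 1; padding is ignored."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, m in zip(token_embeddings, attention_mask):
        if m:
            for i in range(dim):
                total[i] += vec[i]
            count += 1
    # Avoid division by zero when every position is padding
    count = max(count, 1)
    return [t / count for t in total]


# Two real tokens and one padding token: the padding vector does not contribute
print(mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0]))  # [2.0, 3.0]
```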
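```python
# 2_Dense: config for the fully connected layer (no activation)
os.makedirs(os.path.join(ckpt_dir, "2_Dense"), exist_ok=True)
with open(os.path.join(ckpt_dir, "2_Dense", "config.json"), "w", encoding="utf8") as fw:
    json.dump(
        {
            "in_features": model_conf.hidden_size,
            "out_features": VEC_DIM,
            "bias": True,
            "activation_function": "torch.nn.modules.linear.Identity"
        },
        fw,
        ensure_ascii=False,
        indent=1
    )

# Pull the fully connected layer's weights out of the directly saved checkpoint
model_di = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
key_w, key_b = "", ""
for k in model_di:
    if "vec_linear.weight" in k:
        key_w = k
    if "vec_linear.bias" in k:
        key_b = k
# Fail early if the layer (named vec_linear in the training code) was not found
assert key_w and key_b, "vec_linear weights not found in pytorch_model.bin"

# sentence_transformers.models.Dense keeps its layer as self.linear,
# so the keys are renamed to linear.weight / linear.bias
torch.save(
    {
        "linear.weight": model_di[key_w].clone().detach(),
        "linear.bias": model_di[key_b].clone().detach()
    },
    os.path.join(ckpt_dir, "2_Dense", "pytorch_model.bin")
)
```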
```python
# modules.json: the modules in the order they are chained;
# an empty path means the Transformer is loaded from ckpt_dir itself
with open(os.path.join(ckpt_dir, "modules.json"), "w", encoding="utf8") as fw:
    json.dump(
        [
            {
                "idx": 0,
                "name": "0",
                "path": "",
                "type": "sentence_transformers.models.Transformer"
            },
            {
                "idx": 1,
                "name": "1",
                "path": "1_Pooling",
                "type": "sentence_transformers.models.Pooling"
            },
            {
                "idx": 2,
                "name": "2",
                "path": "2_Dense",
                "type": "sentence_transformers.models.Dense"
            }
        ],
        fw,
        ensure_ascii=False,
        indent=1
    )
```
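sentence-transformers reads `modules.json` and runs the modules one after another in the listed order: the Transformer at the root, then Pooling, then Dense. The sketch below only parses the file to show that mapping from entries to sub-directories; `module_chain` is a hypothetical helper, not part of the library.

```python
import json
from typing import List, Tuple


def module_chain(modules_json_text: str) -> List[Tuple[str, str]]:
    """Return (module class name, sub-directory) pairs in the order they appear."""
    entries = json.loads(modules_json_text)
    # "." stands for the checkpoint root (an empty "path")
    return [(e["type"].rsplit(".", 1)[-1], e["path"] or ".") for e in entries]


text = json.dumps([
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
    {"idx": 2, "name": "2", "path": "2_Dense", "type": "sentence_transformers.models.Dense"},
])
print(module_chain(text))
# [('Transformer', '.'), ('Pooling', '1_Pooling'), ('Dense', '2_Dense')]
```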
```python
# sentence_bert_config.json: settings for the Transformer module at the root
with open(os.path.join(ckpt_dir, "sentence_bert_config.json"), "w", encoding="utf8") as fw:
    json.dump(
        {
            "max_seq_length": 512,
            "do_lower_case": False
        },
        fw,
        ensure_ascii=False,
        indent=1
    )
```
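Before loading the directory with `SentenceTransformer(ckpt_dir)`, it may be worth checking that the configs agree on the hidden size, since a mismatch would otherwise likely surface only at load or encode time. A minimal sketch, assuming the layout written above; `check_dims` and `_read_json` are hypothetical helpers.

```python
import json
import os


def _read_json(path: str) -> dict:
    with open(path, encoding="utf8") as f:
        return json.load(f)


def check_dims(ckpt_dir: str) -> int:
    """Verify that the hidden sizes in the saved configs agree; return the final vector size."""
    pooling = _read_json(os.path.join(ckpt_dir, "1_Pooling", "config.json"))
    dense = _read_json(os.path.join(ckpt_dir, "2_Dense", "config.json"))
    if pooling["word_embedding_dimension"] != dense["in_features"]:
        raise ValueError(
            f"Pooling outputs {pooling['word_embedding_dimension']} dims, "
            f"but Dense expects {dense['in_features']}"
        )
    # The backbone's own config.json, if present, should report the same hidden_size
    root_cfg = os.path.join(ckpt_dir, "config.json")
    if os.path.isfile(root_cfg):
        hidden = _read_json(root_cfg).get("hidden_size")
        if hidden is not None and hidden != dense["in_features"]:
            raise ValueError(f"backbone hidden_size is {hidden}, Dense expects {dense['in_features']}")
    return dense["out_features"]
```

If `check_dims(ckpt_dir)` returns the expected VEC_DIM, the loaded model's `get_sentence_embedding_dimension()` should report the same value.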