NormHead 中的分支判断

#3
by JaheimLee - opened

您好,请教NormHead的forward中为什么采用三个分支来生成norm_weight啊,直接norm_weight = nn.functional.normalize(self.weight)会有什么问题吗?另外,forward中存在nn.Parameter会使deepspeed报错,可以避免这个问题吗?感谢!

Baichuan Intelligent Technology org

您好,请教NormHead的forward中为什么采用三个分支来生成norm_weight啊,直接norm_weight = nn.functional.normalize(self.weight)会有什么问题吗?另外,forward中存在nn.Parameter会使deepspeed报错,可以避免这个问题吗?感谢!

训练的时候直接norm_weight = nn.functional.normalize(self.weight)是可以的,这么做主要是为了减少计算,提高性能。如果是训练,你可以改成直接normalize的方式也行。

This comment has been hidden

Sign up or log in to comment