What is the purpose of the dim_model_base parameter?
#7 · opened by ShaneSue
```python
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
    lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
    logits = torch.cat(logits, dim=-1)
else:
    logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
logits = logits.float()
```
I came across this snippet in the model file. When tp = 1, the hidden states are divided by hidden_size / dim_model_base before the lm_head projection, which effectively scales the logits. What is the purpose of this?
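To make the question concrete, here is a minimal sketch of what the tp = 1 branch computes. The hidden_size, dim_model_base, and vocab_size values below are placeholders chosen for illustration, not necessarily the values in this model's config:

```python
import torch
import torch.nn as nn

# Hypothetical config values, for illustration only.
hidden_size = 2304
dim_model_base = 256
vocab_size = 32000

lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(1, 8, hidden_size)

# Dividing the hidden states by (hidden_size / dim_model_base) before a
# bias-free linear projection is the same as multiplying the logits by
# dim_model_base / hidden_size afterwards.
scale = dim_model_base / hidden_size  # 256 / 2304 ≈ 0.111
logits_scaled = lm_head(hidden_states / (hidden_size / dim_model_base))
logits_plain = lm_head(hidden_states) * scale

print(torch.allclose(logits_scaled, logits_plain, atol=1e-6))  # True (up to float error)
```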