What is the purpose of the dim_model_base parameter?

#7 by ShaneSue
```python
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
    lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
    logits = torch.cat(logits, dim=-1)
else:
    logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base))
logits = logits.float()
```

I came across this snippet in the model file. When pretraining_tp == 1, the hidden states are divided by hidden_size / dim_model_base before being passed through lm_head. What is the purpose of this scaling?
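To make the question concrete, here is a minimal sketch of what the division amounts to numerically. The hidden_size and dim_model_base values below are purely illustrative (check the config.json of the actual checkpoint); since lm_head is a bias-free linear layer, dividing its input by hidden_size / dim_model_base is equivalent to multiplying the resulting logits by dim_model_base / hidden_size.

```python
import torch
import torch.nn as nn

# Illustrative values only, not taken from any specific config.
hidden_size = 2304
dim_model_base = 256
vocab_size = 32

lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(1, 4, hidden_size)

# Scaling the input of a bias-free linear layer...
logits_a = lm_head(hidden_states / (hidden_size / dim_model_base))

# ...is the same as scaling its output by dim_model_base / hidden_size.
logits_b = lm_head(hidden_states) * (dim_model_base / hidden_size)

print(torch.allclose(logits_a, logits_b, atol=1e-6))  # True
```

So in effect the logits are shrunk by a constant factor tied to the model width; my question is why the model does this.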
