为什么计算softmax之前要将logits转为float?

#10
by yuanshuai - opened
    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    logits = logits.float()

modeling中为什么有这样的操作,由于词表过大,序列长度较长时这一步显著增加了显存的开销。请问这样操作是出于什么考虑呢?

Sign up or log in to comment