duzx16 commited on
Commit
0deb1dd
1 Parent(s): 12c8049

Fix default dtype for classification head

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -1139,7 +1139,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1139
  self.num_labels = config.num_labels
1140
  self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1141
 
1142
- self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1143
  if config.classifier_dropout is not None:
1144
  self.dropout = nn.Dropout(config.classifier_dropout)
1145
  else:
 
1139
  self.num_labels = config.num_labels
1140
  self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1141
 
1142
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
1143
  if config.classifier_dropout is not None:
1144
  self.dropout = nn.Dropout(config.classifier_dropout)
1145
  else: