duzx16 committed on
Commit 7c48048
1 Parent(s): b1502f4

Add classifier

Files changed (3):
  1. config.json +2 -1
  2. configuration_chatglm.py +2 -0
  3. modeling_chatglm.py +88 -0
config.json CHANGED
@@ -8,7 +8,8 @@
     "AutoConfig": "configuration_chatglm.ChatGLMConfig",
     "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
     "AutoModelForCausalLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
-    "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
+    "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration",
+    "AutoModelForSequenceClassification": "modeling_chatglm.ChatGLMForSequenceClassification"
   },
   "add_bias_linear": false,
   "add_qkv_bias": true,
configuration_chatglm.py CHANGED
@@ -13,6 +13,7 @@ class ChatGLMConfig(PretrainedConfig):
         num_attention_heads=32,
         seq_length=2048,
         hidden_dropout=0.0,
+        classifier_dropout=None,
         attention_dropout=0.0,
         layernorm_epsilon=1e-5,
         rmsnorm=True,
@@ -40,6 +41,7 @@ class ChatGLMConfig(PretrainedConfig):
         self.num_attention_heads = num_attention_heads
         self.seq_length = seq_length
         self.hidden_dropout = hidden_dropout
+        self.classifier_dropout = classifier_dropout
         self.attention_dropout = attention_dropout
         self.layernorm_epsilon = layernorm_epsilon
         self.rmsnorm = rmsnorm
modeling_chatglm.py CHANGED
@@ -11,12 +11,14 @@ import torch.utils.checkpoint
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
@@ -1191,3 +1193,89 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
                                             **kwargs)
         return self
+
+
+class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
+    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+
+        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
+        if config.classifier_dropout is not None:
+            self.dropout = nn.Dropout(config.classifier_dropout)
+        else:
+            self.dropout = None
+        self.config = config
+
+        if self.config.quantization_bit:
+            self.quantize(self.config.quantization_bit, empty_init=True)
+
+    def forward(
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            full_attention_mask: Optional[torch.Tensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+            inputs_embeds: Optional[torch.LongTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            full_attention_mask=full_attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = transformer_outputs[0]
+        pooled_hidden_states = hidden_states[-1]
+        if self.dropout is not None:
+            pooled_hidden_states = self.dropout(pooled_hidden_states)
+        logits = self.classifier_head(pooled_hidden_states)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits.float(), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
+
+        if not return_dict:
+            output = (logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
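
ChatGLMForSequenceClassification pools by taking hidden_states[-1], i.e. the last position along the sequence-first axis of ChatGLMModel's output (the final token's representation), and the loss follows the usual problem_type convention: MSE for num_labels == 1, cross-entropy for integer labels with more than one class, BCE-with-logits otherwise. A standalone sketch of that branching with dummy tensors; classification_loss is a hypothetical helper, not part of the commit:

import torch
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss

def classification_loss(logits, labels, num_labels, problem_type=None):
    # Mirrors the branching added in ChatGLMForSequenceClassification.forward.
    if problem_type is None:
        if num_labels == 1:
            problem_type = "regression"
        elif labels.dtype in (torch.long, torch.int):
            problem_type = "single_label_classification"
        else:
            problem_type = "multi_label_classification"

    if problem_type == "regression":
        return MSELoss()(logits.squeeze().float(), labels.squeeze())
    if problem_type == "single_label_classification":
        return CrossEntropyLoss()(logits.view(-1, num_labels).float(), labels.view(-1))
    return BCEWithLogitsLoss()(logits.float(), labels.view(-1, num_labels))

# Single-label example: batch of 4, 3 classes, integer labels -> cross-entropy.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
print(classification_loss(logits, labels, num_labels=3))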