import torch
import torch.nn as nn
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional, Tuple, Union

from .configuration_glm2 import gLM2Config
from .modeling_glm2 import gLM2Model, gLM2PreTrainedModel


class gLM2ClassificationConfig(gLM2Config):
    def __init__(self, num_classes: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # Expose the classification head to the Auto API when this repo is
        # loaded with trust_remote_code=True.
        self.auto_map['AutoModelForSequenceClassification'] = "extension_glm2.gLM2ForSequenceClassification"
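
# With the auto_map entry above, the head can be loaded through the Auto API.
# A minimal sketch (the repo id "tattabio/gLM2_650M" is illustrative, and this
# path only resolves for checkpoints saved with this config):
#
#   from transformers import AutoModelForSequenceClassification
#   model = AutoModelForSequenceClassification.from_pretrained(
#       "tattabio/gLM2_650M", num_classes=2, trust_remote_code=True
#   )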


class gLM2ForSequenceClassification(gLM2PreTrainedModel):
    config_class = gLM2ClassificationConfig

    def __init__(self, config: gLM2ClassificationConfig):
        super().__init__(config)
        self.glm2 = gLM2Model(config)
        # Linear classification head over the first token's hidden state.
        self.score = nn.Linear(config.dim, config.num_classes, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.glm2.tok_embeddings

    def set_input_embeddings(self, value):
        self.glm2.tok_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.glm2(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = outputs[0]
        # Pool with the first token (<+>), which serves as the CLS token.
        cls_token = token_embeddings[:, 0, :]
        logits = self.score(cls_token)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
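

# End-to-end sketch (assumptions: the checkpoint id "tattabio/gLM2_650M" and
# the amino-acid sequence are illustrative; the tokenizer is expected to place
# the <+> strand token at position 0, where the head pools):
#
#   import torch
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("tattabio/gLM2_650M", trust_remote_code=True)
#   model = gLM2ForSequenceClassification.from_pretrained("tattabio/gLM2_650M", num_classes=2)
#   enc = tokenizer("<+>MALWMRLLPLLALLALWGPDPAAA", return_tensors="pt")
#   out = model(enc["input_ids"], labels=torch.tensor([0]))
#   print(out.logits, out.loss)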