# gLM2_150M-promoter_tata-lora / extension_glm2.py

import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_glm2 import gLM2Config
from .modeling_glm2 import gLM2Model, gLM2PreTrainedModel


class gLM2ClassificationConfig(gLM2Config):
    def __init__(self, num_classes: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # gLM2Config may not carry an auto_map when built from scratch.
        if not hasattr(self, "auto_map"):
            self.auto_map = {}
        # Register the classification head so AutoModelForSequenceClassification
        # can resolve this class with trust_remote_code=True.
        self.auto_map["AutoModelForSequenceClassification"] = "extension_glm2.gLM2ForSequenceClassification"


class gLM2ForSequenceClassification(gLM2PreTrainedModel):
    config_class = gLM2ClassificationConfig

    def __init__(self, config: gLM2ClassificationConfig):
        super().__init__(config)
        self.glm2 = gLM2Model(config)
        # Linear head mapping the pooled token embedding to class logits.
        self.score = nn.Linear(config.dim, config.num_classes, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.glm2.tok_embeddings

    def set_input_embeddings(self, value):
        self.glm2.tok_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.glm2(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = outputs[0]
        # gLM2 has no dedicated CLS token; pool by taking the embedding of the
        # leading <+> strand token at position 0.
        cls_token = token_embeddings[:, 0, :]
        logits = self.score(cls_token)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
        if not return_dict:
            # As a tuple, gLM2Model yields (last_hidden_state, hidden_states, ...),
            # so pass everything past index 0 through alongside the logits.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
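

# Minimal usage sketch, not part of the modeling code. Assumptions: `repo_id`
# below is a hypothetical placeholder built from this repository's name; the
# repository's config.json, weights, and tokenizer are loadable from it; and
# the tokenizer accepts the <+> strand token relied on by the pooling above.
# trust_remote_code=True lets AutoModelForSequenceClassification resolve this
# class via the auto_map entry registered in gLM2ClassificationConfig.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    repo_id = "alejandralopezsosa/gLM2_150M-promoter_tata-lora"  # hypothetical placeholder
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    # Score a short DNA sequence; <+> marks the forward strand in gLM2 inputs.
    inputs = tokenizer("<+>ACGTACGTACGT", return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)
    print(out.logits.shape)  # (1, num_classes)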