# stance-pl / models.py
import logging
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, BertForMaskedLM, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class StanceEncoderModel(PreTrainedModel):
    """BERT-based stance classifier.

    The encoder output at the [MASK] position is classified either with a small
    MLP head or, in verbalizer mode, by scoring selected vocabulary tokens with
    the pretrained masked-LM head.
    """

    config_class = BertConfig
    logger = logging.getLogger("StanceEncoderModel")

    def __init__(self, config):
        super().__init__(config)
        # Model hyperparameters are read from `config.task_specific_params`.
        task_specific_params = config.task_specific_params
        self.num_labels = task_specific_params.get('num_labels', 3)
        self.mask_token_id = task_specific_params['mask_token_id']
        self.verbalizer_token_ids = task_specific_params['verbalizer_token_ids']
        self.clf_hidden_dim = task_specific_params.get('clf_hidden_dim', 300)
        self.clf_drop_prob = task_specific_params.get('clf_drop_prob', 0.2)
        self.clf_gelu_head = task_specific_params.get('clf_gelu_head', False)
        self.masked_lm = task_specific_params.get('masked_lm', True)
        self.masked_lm_n_tokens = task_specific_params.get('masked_lm_tokens', 1)
        self.masked_lm_verbalizer = task_specific_params.get('masked_lm_verbalizer', False)

        # Reuse the encoder and the masked-LM head of a standard BertForMaskedLM.
        base_model = BertForMaskedLM(config)
        self.base_enc_model = base_model.bert
        self.lm_head = base_model.cls

        hidden_size_multiplier = 1
        if not self.masked_lm_verbalizer:
            if self.clf_gelu_head:
                self.logger.info('using 2 layer gelu classifier head')
                # Two-layer GELU MLP applied to the [MASK] token representation.
                self.classifier = torch.nn.Sequential(
                    torch.nn.Linear(self.config.hidden_size * hidden_size_multiplier, self.clf_hidden_dim),
                    torch.nn.Dropout(self.clf_drop_prob),
                    torch.nn.GELU(),
                    torch.nn.Linear(self.clf_hidden_dim, self.num_labels)
                )
            else:
                raise ValueError('classification type head not specified')

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            sequence_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        outputs = self.base_enc_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Select the hidden state(s) at the [MASK] position(s) of each sequence.
        masked_token_filter = input_ids == self.mask_token_id
        masked_repr = outputs.last_hidden_state[masked_token_filter].reshape(len(input_ids), -1)
        if self.masked_lm_verbalizer:
            # Verbalizer mode: keep only the MLM logits of the verbalizer tokens.
            logits = self.lm_head(masked_repr)[:, self.verbalizer_token_ids]
        else:
            logits = self.classifier(masked_repr)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
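

# A minimal usage sketch (not part of the original file), assuming a trained
# checkpoint directory whose config carries the required `task_specific_params`
# (`mask_token_id`, `verbalizer_token_ids`, ...). The path and the prompt text
# below are hypothetical placeholders.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    checkpoint_dir = 'path/to/stance-pl-checkpoint'  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = StanceEncoderModel.from_pretrained(checkpoint_dir)
    model.eval()

    # The input must contain a [MASK] token; its representation is classified
    # into one of `num_labels` stance classes.
    text = 'Przykładowy tekst wypowiedzi. Stanowisko: [MASK].'  # hypothetical prompt
    encoded = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        output = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
    print(output.logits.argmax(dim=-1).item())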