import logging
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, BertForMaskedLM, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class StanceEncoderModel(PreTrainedModel):
    """BERT-based stance classifier that predicts from the [MASK] token representation,
    either by scoring verbalizer tokens with the MLM head or via a small classification head."""

    config_class = BertConfig
    logger = logging.getLogger("StanceEncoderModel")

    def __init__(self, config):
        super().__init__(config)
        task_specific_params = config.task_specific_params
        self.num_labels = task_specific_params.get('num_labels', 3)
        self.mask_token_id = task_specific_params['mask_token_id']
        self.verbalizer_token_ids = task_specific_params['verbalizer_token_ids']
        self.clf_hidden_dim = task_specific_params.get('clf_hidden_dim', 300)
        self.clf_drop_prob = task_specific_params.get('clf_drop_prob', 0.2)
        self.clf_gelu_head = task_specific_params.get('clf_gelu_head', False)
        self.masked_lm = task_specific_params.get('masked_lm', True)
        self.masked_lm_n_tokens = task_specific_params.get('masked_lm_tokens', 1)
        self.masked_lm_verbalizer = task_specific_params.get('masked_lm_verbalizer', False)

        base_model = BertForMaskedLM(config)
        self.base_enc_model = base_model.bert  # shared BERT encoder
        self.lm_head = base_model.cls          # MLM head, used for verbalizer scoring

        # With a single [MASK] token per example the classifier input is one hidden vector.
        hidden_size_multiplier = 1

        if not self.masked_lm_verbalizer:
            if self.clf_gelu_head:
                self.logger.info('using 2 layer gelu classifier head')
                self.classifier = torch.nn.Sequential(
                    torch.nn.Linear(self.config.hidden_size * hidden_size_multiplier, self.clf_hidden_dim),
                    torch.nn.Dropout(self.clf_drop_prob),
                    torch.nn.GELU(),
                    torch.nn.Linear(self.clf_hidden_dim, self.num_labels)
                )
            else:
                raise ValueError('classification type head not specified')

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        sequence_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        outputs = self.base_enc_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Select the hidden states at the [MASK] position(s) and flatten them
        # into one representation vector per example.
        masked_token_filter = input_ids == self.mask_token_id
        masked_repr = outputs.last_hidden_state[masked_token_filter].reshape(len(input_ids), -1)

        if self.masked_lm_verbalizer:
            # Restrict the MLM vocabulary logits to the verbalizer (label word) tokens.
            logits = self.lm_head(masked_repr)[:, self.verbalizer_token_ids]
        else:
            logits = self.classifier(masked_repr)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
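

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The checkpoint name, label words, and hyperparameter values below are
# assumptions; the verbalizer token ids would normally come from the label
# words used in the actual training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased")
    config.task_specific_params = {
        "num_labels": 3,
        "mask_token_id": tokenizer.mask_token_id,
        # Assumed label words; replace with the verbalizer actually used.
        "verbalizer_token_ids": tokenizer.convert_tokens_to_ids(
            ["favor", "against", "neutral"]
        ),
        "clf_gelu_head": True,
    }

    model = StanceEncoderModel(config)

    # One [MASK] token per example, as the forward pass assumes.
    inputs = tokenizer(
        "Target: climate change. Stance: [MASK]. Renewables are the future.",
        return_tensors="pt",
    )
    outputs = model(**inputs)
    print(outputs.logits.shape)  # expected: torch.Size([1, 3])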