import logging
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, BertForMaskedLM, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput
class StanceEncoderModel(PreTrainedModel):
    config_class = BertConfig
    logger = logging.getLogger("StanceEncoderModel")

    def __init__(self, config):
        super().__init__(config)
        task_specific_params = config.task_specific_params
        self.num_labels = task_specific_params.get('num_labels', 3)
        self.mask_token_id = task_specific_params['mask_token_id']
        self.verbalizer_token_ids = task_specific_params['verbalizer_token_ids']
        self.clf_hidden_dim = task_specific_params.get('clf_hidden_dim', 300)
        self.clf_drop_prob = task_specific_params.get('clf_drop_prob', 0.2)
        self.clf_gelu_head = task_specific_params.get('clf_gelu_head', False)
        self.masked_lm = task_specific_params.get('masked_lm', True)
        self.masked_lm_n_tokens = task_specific_params.get('masked_lm_tokens', 1)
        self.masked_lm_verbalizer = task_specific_params.get('masked_lm_verbalizer', False)

        # Shared BERT encoder plus the pretrained MLM head used for verbalizer scoring.
        base_model = BertForMaskedLM(config)
        self.base_enc_model = base_model.bert
        self.lm_head = base_model.cls

        # The classifier sees the concatenated hidden states of all [MASK] tokens,
        # so its input width scales with the number of mask tokens per example.
        hidden_size_multiplier = self.masked_lm_n_tokens
        if not self.masked_lm_verbalizer:
            if self.clf_gelu_head:
                self.logger.info('using 2 layer gelu classifier head')
                self.classifier = torch.nn.Sequential(
                    torch.nn.Linear(self.config.hidden_size * hidden_size_multiplier, self.clf_hidden_dim),
                    torch.nn.Dropout(self.clf_drop_prob),
                    torch.nn.GELU(),
                    torch.nn.Linear(self.clf_hidden_dim, self.num_labels),
                )
            else:
                raise ValueError('classification type head not specified')
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        sequence_ids: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # Encode the prompt; forward the optional arguments instead of silently dropping them.
        outputs = self.base_enc_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Gather the hidden states at every [MASK] position and concatenate them per example.
        masked_token_filter = input_ids == self.mask_token_id
        masked_repr = outputs.last_hidden_state[masked_token_filter].reshape(len(input_ids), -1)

        if self.masked_lm_verbalizer:
            # Score only the verbalizer (label word) vocabulary entries with the MLM head.
            logits = self.lm_head(masked_repr)[:, self.verbalizer_token_ids]
        else:
            logits = self.classifier(masked_repr)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )