import copy
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional
from typing import OrderedDict as OrderDictType

import torch
import torch.nn as nn
from torch import Tensor

from transformers.activations import get_activation
from transformers.models.electra.modeling_electra import (
    ElectraEncoder,
    ElectraForMultipleChoice,
    ElectraForSequenceClassification,
    ElectraForTokenClassification,
    ElectraLayer,
    ElectraPreTrainedModel,
    ModelOutput,
    SequenceClassifierOutput,
)

from .config import CharmenElectraConfig
from .gbst import GBST


@dataclass
class CharmenElectraModelOutput(ModelOutput):
    """
    Output type of :class:`~.CharmenElectraModel`.

    Args:
        downsampled_hidden_states (:obj:`torch.FloatTensor` of shape
            :obj:`(batch_size, sequence_length // downsampling_factor, hidden_size)`, `optional`):
            Hidden states produced by the main encoder stack at the downsampled resolution.
        upsampled_hidden_states (:obj:`torch.FloatTensor` of shape
            :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Hidden states projected back to the original character-level resolution.
    """
    downsampled_hidden_states: Optional[torch.FloatTensor] = None
    upsampled_hidden_states: Optional[torch.FloatTensor] = None


class CharmenElectraModel(ElectraPreTrainedModel):
    config_class = CharmenElectraConfig

    def __init__(self, config: CharmenElectraConfig, compatibility_with_transformers=False, **kwargs):
        super().__init__(config)
        # GBST (gradient-based subword tokenization, as in Charformer) embeds the
        # character/byte sequence and provides the learned downsampling used below.
        self.embeddings: GBST = GBST(
            num_tokens=config.vocab_size,
            dim=config.embedding_size,
            max_block_size=config.max_block_size,
            downsample_factor=config.downsampling_factor,
            score_consensus_attn=config.score_consensus_attn,
            config=config,
        )
        self.compatibility_with_transformers = compatibility_with_transformers

        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

        self.upsampling = nn.Upsample(scale_factor=config.downsampling_factor, mode='nearest')
        self.upsampling_convolution = nn.Conv1d(in_channels=config.hidden_size * 2,
                                                out_channels=config.hidden_size,
                                                kernel_size=config.downsampling_factor * 2 - 1,
                                                padding='same',
                                                dilation=1)
        self.upsample_output = config.upsample_output

        # The first and last transformer layers run at full (character-level) resolution,
        # so the inner encoder keeps the remaining num_hidden_layers - 2 layers.
        cfg = copy.deepcopy(config)
        cfg.num_hidden_layers = config.num_hidden_layers - 2
        self.encoder = ElectraEncoder(cfg)

        self.encoder_first_layer = ElectraLayer(config)
        self.encoder_last_layer = ElectraLayer(config)

        self.config = config
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.

        Args:
            heads_to_prune: dict of ``{layer_num: list of heads to prune in this layer}``.
                See the base class :class:`~transformers.PreTrainedModel`.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Promote a single un-batched example to a batch of size one.
        if input_ids is not None and input_ids.dim() == 1:
            input_ids = input_ids.view(1, -1)
            if attention_mask is not None:
                attention_mask = attention_mask.view(1, -1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids.view(1, -1)

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Keep the character-level mask; `attention_mask` is replaced by its downsampled
        # counterpart further below.
        unscaled_attention_mask = torch.clone(attention_mask)

        _, _, unscaled_hidden_states = self.embeddings(
            input_ids=input_ids, attention_mask=attention_mask,
            position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        if hasattr(self, "embeddings_project"):
            unscaled_hidden_states = self.embeddings_project(unscaled_hidden_states)

        extended_unscaled_attention_mask = self.get_extended_attention_mask(unscaled_attention_mask, input_shape,
                                                                            device)
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # One full-resolution transformer layer before downsampling. The positional arguments
        # correspond to (attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask,
        # past_key_value, output_attentions).
        unscaled_hidden_states = self.encoder_first_layer(unscaled_hidden_states, extended_unscaled_attention_mask,
                                                          None, None, None, None, False)[0]

        # Downsample the hidden states and the attention mask by `downsampling_factor`.
        hidden_states, attention_mask = self.embeddings.down_sample(unscaled_hidden_states, unscaled_attention_mask,
                                                                    self.config.downsampling_factor)
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        encoder_output = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        downsampled_hidden_states = encoder_output[0]
        hidden_states = encoder_output[0]

        # Upsample back to the character-level resolution, concatenate with the
        # full-resolution states from before downsampling, and fuse them with a convolution.
        upsampled = self.upsampling(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)
        hidden_states = torch.cat([unscaled_hidden_states, upsampled], dim=-1)

        hidden_states = self.upsampling_convolution(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)

        # One final full-resolution transformer layer on the upsampled states.
        hidden_states = self.encoder_last_layer(hidden_states, extended_unscaled_attention_mask,
                                                None, None, None, None, False)

        upsampled_output = hidden_states[0]

        return CharmenElectraModelOutput(
            downsampled_hidden_states=downsampled_hidden_states,
            upsampled_hidden_states=upsampled_output,
        )

    def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
        # Pre-training checkpoints store the backbone under the "discriminator.electra."
        # prefix; strip it and skip the generator weights before loading.
        model = OrderedDict()
        prefix = "discriminator.electra."

        for key, value in state_dict.items():
            if key.startswith('generator'):
                continue
            if key.startswith(prefix):
                model[key[len(prefix):]] = value
            else:
                continue

        return super(CharmenElectraModel, self).load_state_dict(state_dict=model, strict=strict)
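

# Illustrative usage sketch (not part of the module); it assumes a `CharmenElectraConfig`
# instance `config` and character/byte-level `input_ids` of shape (batch_size, seq_len):
#
#   model = CharmenElectraModel(config)
#   out = model(input_ids=input_ids, attention_mask=attention_mask)
#   out.downsampled_hidden_states  # (batch_size, seq_len // config.downsampling_factor, hidden_size)
#   out.upsampled_hidden_states    # (batch_size, seq_len, hidden_size)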


class CharmenElectraClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: CharmenElectraConfig, **kwargs):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
        self.ds_factor = config.downsampling_factor

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take the representation of the first ([CLS]-like) position
        x = self.dropout(x)
        x = self.dense(x)
        x = get_activation(self.config.summary_activation)(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
    config_class = CharmenElectraConfig

    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.config = config
        self.carmen_config = config
        self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)
        self.classifier = CharmenElectraClassificationHead(config)
        self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_discriminator: CharmenElectraModelOutput = self.electra(input_ids, attention_mask, token_type_ids)

        if self.carmen_config.upsample_output:
            cls = self.classifier(output_discriminator.upsampled_hidden_states)
        else:
            cls = self.classifier(output_discriminator.downsampled_hidden_states)

        cls_loss = self.cls_loss_fct(cls, labels) if labels is not None else None

        return SequenceClassifierOutput(
            loss=cls_loss,
            logits=cls,
            hidden_states=output_discriminator.downsampled_hidden_states,
            attentions=None,
        )

    def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
        # Pre-training checkpoints prefix the backbone with "discriminator."; strip it and
        # skip generator, discriminator-head and SOP weights before loading.
        model = OrderedDict()
        prefix = "discriminator."

        for key, value in state_dict.items():
            if key.startswith('generator'):
                continue
            if key.startswith(prefix):
                if 'discriminator_predictions' in key:
                    continue
                model[key[len(prefix):]] = value
            else:
                if key.startswith('sop'):
                    continue
                model[key] = value

        return super(CharmenElectraForSequenceClassification, self).load_state_dict(state_dict=model, strict=False)
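

# Illustrative usage sketch (not part of the module); it assumes character/byte-level
# `input_ids` of shape (batch_size, seq_len) and integer `labels` of shape (batch_size,):
#
#   model = CharmenElectraForSequenceClassification(config)
#   out = model(input_ids=input_ids, attention_mask=attention_mask,
#               token_type_ids=token_type_ids, labels=labels)
#   out.loss    # cross-entropy loss (None when labels are omitted)
#   out.logits  # (batch_size, num_labels)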


class CharmenElectraForTokenClassification(ElectraForTokenClassification):
    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.config = config

        self.carmen_config = config
        self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_discriminator: CharmenElectraModelOutput = self.electra(
            input_ids, attention_mask, token_type_ids)

        # Token-level logits are computed on the upsampled states, i.e. one prediction
        # per input character/byte position.
        discriminator_sequence_output = self.dropout(output_discriminator.upsampled_hidden_states)
        logits = self.classifier(discriminator_sequence_output)

        if labels is not None:
            cls_loss = self.cls_loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        else:
            cls_loss = None

        return SequenceClassifierOutput(
            loss=cls_loss,
            logits=logits,
            hidden_states=output_discriminator.upsampled_hidden_states,
            attentions=None,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.electra.get_input_embeddings()

    def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
        model = OrderedDict()
        prefix = "discriminator."

        for key, value in state_dict.items():
            if key.startswith('generator'):
                continue
            if key.startswith(prefix):
                if 'discriminator_predictions' in key:
                    continue
                model[key[len(prefix):]] = value
            else:
                if key.startswith('sop'):
                    continue
                model[key] = value

        return super(CharmenElectraForTokenClassification, self).load_state_dict(state_dict=model, strict=False)
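

# Illustrative usage sketch (not part of the module). Because classification runs on the
# upsampled states, `labels` are per character/byte position, shape (batch_size, seq_len):
#
#   model = CharmenElectraForTokenClassification(config)
#   out = model(input_ids=input_ids, attention_mask=attention_mask,
#               token_type_ids=token_type_ids, labels=labels)
#   out.logits  # (batch_size, seq_len, num_labels)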


class Pooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pool by taking the hidden state of the first token, as in BERT.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.carmen_config = config
        self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)
        self.pooler = Pooler(config)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Flatten (batch_size, num_choices, seq_len) into (batch_size * num_choices, seq_len)
        # so that every choice is encoded independently.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None

        output_discriminator: CharmenElectraModelOutput = self.electra(
            input_ids, attention_mask, token_type_ids)

        if self.carmen_config.upsample_output:
            pooled_output = self.pooler(output_discriminator.upsampled_hidden_states)
        else:
            pooled_output = self.pooler(output_discriminator.downsampled_hidden_states)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # One score per choice; regroup back to (batch_size, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        cls_loss = self.cls_loss_fct(reshaped_logits, labels) if labels is not None else None

        return SequenceClassifierOutput(
            loss=cls_loss,
            logits=reshaped_logits,
            hidden_states=output_discriminator.downsampled_hidden_states,
            attentions=None,
        )

    def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
        model = OrderedDict()
        prefix = "discriminator."

        for key, value in state_dict.items():
            if key.startswith('generator'):
                continue
            if key.startswith(prefix):
                if 'discriminator_predictions' in key:
                    continue
                model[key[len(prefix):]] = value
            else:
                if key.startswith('sop'):
                    continue
                model[key] = value

        return super(CharmenElectraForMultipleChoice, self).load_state_dict(state_dict=model, strict=False)
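

# Illustrative usage sketch (not part of the module). Inputs carry an extra "choice"
# dimension: `input_ids` has shape (batch_size, num_choices, seq_len) and `labels` has
# shape (batch_size,), indexing the correct choice:
#
#   model = CharmenElectraForMultipleChoice(config)
#   out = model(input_ids=input_ids, attention_mask=attention_mask,
#               token_type_ids=token_type_ids, labels=labels)
#   out.logits  # (batch_size, num_choices)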