# EBERT_ru_MLM / ebert_model.py
# Uploaded by Darkester ("Upload 2 files", commit b61b1b5, verified)
from transformers import BertConfig, BertModel
import torch.nn as nn
class EBertConfig(BertConfig):
    """BertConfig extended with an optional bottleneck-adapter size.

    ``adapter_size`` is the hidden dimension of the per-layer adapter
    bottleneck used by :class:`EBertModel`; ``None`` disables adapters.
    """

    model_type = "ebert"

    def __init__(self, adapter_size=None, **kwargs):
        # Bug fix: the original popped 'adapter_size' from kwargs *after*
        # already forwarding **kwargs to super().__init__, so the pop had no
        # effect on what the parent saw and the value reached the base config
        # as an unknown kwarg. Declaring it as an explicit keyword argument
        # keeps it out of **kwargs while remaining call-compatible
        # (callers pass adapter_size by keyword either way).
        super().__init__(**kwargs)
        self.adapter_size = adapter_size
class EBertModel(BertModel):
    """BertModel with optional residual bottleneck adapters.

    When ``config.adapter_size`` is truthy, one two-layer MLP adapter
    (hidden -> adapter_size -> hidden) is created per encoder layer and
    applied residually to the final hidden states.
    """

    config_class = EBertConfig

    def __init__(self, config: EBertConfig):
        super().__init__(config)
        if config.adapter_size:
            # One bottleneck adapter per encoder layer.
            self.adapters = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(config.hidden_size, config.adapter_size),
                    nn.ReLU(),
                    nn.Linear(config.adapter_size, config.hidden_size),
                )
                for _ in range(config.num_hidden_layers)
            ])
        else:
            self.adapters = None

    def forward(self, *args, **kwargs):
        """Run the base BertModel, then apply adapters residually.

        NOTE(review): every adapter is applied sequentially to the *final*
        hidden states rather than interleaved with its encoder layer —
        confirm this matches the intended EBERT design.
        """
        outputs = super().forward(*args, **kwargs)
        if self.adapters is None:
            # No adapters: return the base output untouched (the original
            # rebuilt it anyway, dropping fields — see below).
            return outputs
        # Fixes vs. original:
        # 1) With return_dict=False the base model returns a tuple, and the
        #    original crashed on `.last_hidden_state`; handle both shapes.
        # 2) The original reconstructed the output listing only four fields,
        #    silently dropping anything else the base model produced (e.g.
        #    past_key_values, cross_attentions). Update in place instead.
        if isinstance(outputs, tuple):
            hidden = outputs[0]
            for adapter in self.adapters:
                hidden = hidden + adapter(hidden)
            return (hidden,) + outputs[1:]
        hidden = outputs.last_hidden_state
        for adapter in self.adapters:
            hidden = hidden + adapter(hidden)
        outputs.last_hidden_state = hidden
        return outputs