from transformers import BertConfig, BertModel
import torch.nn as nn
class EBertConfig(BertConfig):
    model_type = "ebert"

    def __init__(self, adapter_size=None, **kwargs):
        super().__init__(**kwargs)
        # Bottleneck width of the adapters; None disables them.
        self.adapter_size = adapter_size
class EBertModel(BertModel):
    config_class = EBertConfig

    def __init__(self, config: EBertConfig):
        super().__init__(config)
        if config.adapter_size:
            # One bottleneck adapter (down-project, ReLU, up-project) per encoder layer.
            self.adapters = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(config.hidden_size, config.adapter_size),
                    nn.ReLU(),
                    nn.Linear(config.adapter_size, config.hidden_size),
                )
                for _ in range(config.num_hidden_layers)
            ])
        else:
            self.adapters = None
    def forward(self, *args, **kwargs):
        # Assumes dict-style outputs (return_dict=True, the BertModel default).
        outputs = super().forward(*args, **kwargs)
        sequence_output = outputs.last_hidden_state
        if self.adapters is not None:
            # Apply each adapter as a residual update on the final hidden states.
            for adapter in self.adapters:
                sequence_output = sequence_output + adapter(sequence_output)
        return outputs.__class__(
            last_hidden_state=sequence_output,
            pooler_output=outputs.pooler_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
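

# Usage sketch (illustrative addition, not part of the classes above): build a
# randomly initialised EBertModel and push a dummy batch through it. The
# adapter_size of 64 and the batch/sequence sizes are example choices, not
# values taken from the original code.
if __name__ == "__main__":
    import torch

    config = EBertConfig(adapter_size=64)
    model = EBertModel(config)

    # Dummy batch: 2 sequences of 16 token ids plus an all-ones attention mask.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # The adapters preserve BERT's hidden size: (batch, seq_len, hidden_size).
    print(outputs.last_hidden_state.shape)  # torch.Size([2, 16, 768])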