|
import torch.nn as nn |
|
from transformers import PreTrainedModel, AutoConfig, AutoModel |
|
|
|
from .configuration_sentence_embedder import SentenceEmbedderConfig |
|
|
|
class SentenceEmbedderModel(PreTrainedModel):
    """Transformer backbone plus a linear projection for sentence embeddings.

    Wraps an arbitrary ``AutoModel`` backbone (selected by
    ``config.backbone_name``) and projects its token-level hidden states to
    ``config.output_size`` dimensions.
    """

    config_class = SentenceEmbedderConfig

    def __init__(self, config):
        """Build the backbone and projection head.

        Args:
            config: A ``SentenceEmbedderConfig``. ``config.init_backbone``
                selects between loading pretrained backbone weights
                (``from_pretrained``) and random initialization from the
                backbone's architecture config (``from_config``) — the latter
                is the right path when weights will be loaded later from a
                saved checkpoint of this wrapper.
        """
        super().__init__(config)
        if config.init_backbone:
            self.backbone = AutoModel.from_pretrained(config.backbone_name)
        else:
            backbone_config = AutoConfig.from_pretrained(config.backbone_name)
            self.backbone = AutoModel.from_config(backbone_config)
        # Project token-level hidden states down/up to the embedding size.
        self.projection = nn.Linear(self.backbone.config.hidden_size, config.output_size)

    def forward(self, input_ids, attention_mask, head=None):
        """Encode tokens and project the hidden states.

        Args:
            input_ids: Token id tensor — presumably (batch, seq); shape is
                whatever the backbone accepts.
            attention_mask: Padding mask matching ``input_ids``.
            head: Unused; kept for backward compatibility with existing
                callers.

        Returns:
            The backbone's output object with ``last_hidden_state`` replaced
            by its projection to ``config.output_size`` dimensions.
        """
        # Pass attention_mask by keyword: the second *positional* parameter
        # of forward() differs across architectures (e.g. GPT2Model takes
        # past_key_values there), so a positional mask can be silently
        # misrouted and attention would run unmasked.
        outputs = self.backbone(input_ids, attention_mask=attention_mask)
        outputs.last_hidden_state = self.projection(outputs.last_hidden_state)
        return outputs