from torch import nn

from configuration import CFG
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer


class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        # freeze or unfreeze the text encoder depending on the `trainable` flag
        for p in self.model.parameters():
            p.requires_grad = trainable

        # we are using the CLS token hidden representation as the sentence's embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]
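

# A minimal usage sketch, not part of the original file: it assumes that
# `configuration.CFG.text_encoder_model` names a DistilBERT checkpoint such as
# "distilbert-base-uncased" and that CFG.pretrained / CFG.trainable are set.
# It shows how a tokenized batch flows through TextEncoder and yields one
# CLS embedding per sentence.
if __name__ == "__main__":
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_encoder_model)
    batch = tokenizer(
        ["a photo of a dog", "a photo of a cat"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    encoder = TextEncoder()
    embeddings = encoder(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )

    # each sentence is represented by its CLS token vector: (batch_size, 768)
    print(embeddings.shape)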