File size: 845 Bytes
3b413ba 0ae896e 3b413ba 0ae896e 3b413ba d9649c3 3b413ba 898adc4 9c570b4 77e83ce 0728fa3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
from torch import nn
from transformers import MobileBertPreTrainedModel, MobileBertModel
class SimModel(MobileBertPreTrainedModel):
    """Mean-pooled word-embedding encoder on the MobileBERT pretrained-model scaffolding.

    Only a word-embedding table is used; no transformer layers are applied.
    Inheriting from ``MobileBertPreTrainedModel`` provides the config handling
    and weight initialisation (via ``post_init``).
    """

    def __init__(self, config):
        """Build the embedding table from ``config``.

        Args:
            config: A MobileBERT config. Its ``vocab_size``, ``embedding_size``
                and ``pad_token_id`` fields are read here.
        """
        super().__init__(config)
        self.config = config
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.embedding_size,
            padding_idx=config.pad_token_id,
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids, attention_mask=None, return_dict=None):
        """Embed ``input_ids`` and mean-pool the embeddings over unmasked tokens.

        Args:
            input_ids: Token ids, assumed to have shape (batch, seq_len).
            attention_mask: Mask of shape (batch, seq_len). Nonzero entries mark
                tokens to include in the mean. If ``None``, every token is
                included.
            return_dict: Accepted for call-site compatibility and ignored; a
                tuple is always returned.

        Returns:
            A tuple ``(embeddings, mean_pooled_embeddings, embeddings)``:
            ``embeddings`` is (batch, seq_len, embedding_size) and is not
            masked. ``mean_pooled_embeddings`` is (batch, embedding_size).
            The third element repeats ``embeddings``, presumably to fill a
            hidden-states slot for callers; confirm this against the callers.
        """
        embeddings = self.word_embeddings(input_ids)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=embeddings.dtype)

        # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts across
        # the embedding dimension. It is cast once so that int/bool masks
        # multiply and sum in the embedding dtype.
        mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
        summed = (embeddings * mask).sum(dim=1)

        # A row whose mask is all zeros would otherwise divide 0 by 0 and put
        # NaN into the output (and the gradients). With the clamp, such a row
        # pools to a zero vector.
        token_counts = mask.sum(dim=1).clamp(min=1.0)
        mean_pooled_embeddings = summed / token_counts

        return (embeddings, mean_pooled_embeddings, embeddings)