File size: 657 Bytes
3b413ba
 
7186847
3b413ba
7319aeb
3b413ba
 
 
65e061d
3b413ba
 
 
196626e
a38759a
ca33b03
08f4fc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from torch import nn
from transformers import PreTrainedMobileBertModel, MobileBertModel

class SimModel(MobileBertModel):
    """MobileBERT model that delegates its forward pass to a nested ``MobileBertModel``.

    The instance holds a second ``MobileBertModel`` (built from the same
    ``config``) under ``self.encoder``, and ``forward`` routes every call
    through it.

    NOTE(review): because this class *subclasses* ``MobileBertModel`` and calls
    ``super().__init__(config)``, the parent constructor presumably builds its
    own sub-modules first; assigning ``self.encoder`` then replaces the parent's
    ``encoder`` attribute, and the parent's other sub-modules are never used by
    ``forward``. That likely duplicates parameters and may stop pretrained
    checkpoints from mapping onto the nested model's weights. Consider
    subclassing the MobileBERT pretrained-model base class instead. The
    ``PreTrainedMobileBertModel`` name imported at the top of the file does not
    appear to exist in ``transformers`` (``MobileBertPreTrainedModel`` may be
    what was meant) — confirm before changing the base class.
    """

    def __init__(self, config):
        """Build the wrapper.

        Args:
            config: MobileBERT configuration object, passed both to the parent
                constructor and to the nested ``MobileBertModel``.
        """
        super().__init__(config)
        self.config = config
        self.encoder = MobileBertModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the nested MobileBERT model on the given inputs.

        Every argument is forwarded unchanged to ``self.encoder``. The
        parameters now default to ``None`` so callers that pass them
        positionally keep working, while callers may also omit them.

        Args:
            input_ids: Token ids; presumably shaped (batch, seq_len) — confirm.
            attention_mask: Mask for ``input_ids``; forwarded as-is.
            token_type_ids: Segment ids; forwarded as-is.
            return_dict: Forwarded as-is. ``None`` presumably lets the nested
                model fall back to its config default — confirm.
            **kwargs: Any further keyword the nested model's ``forward``
                accepts (e.g. ``position_ids``, ``output_hidden_states``).

        Returns:
            Whatever ``self.encoder(...)`` returns.
        """
        # No stdout logging here: this runs once per batch and the inputs are
        # full tensors.
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=return_dict,
            **kwargs,
        )