Update modeling_sim.py
Browse files- modeling_sim.py +2 -4
modeling_sim.py
CHANGED
@@ -6,12 +6,10 @@ class SimModel(MobileBertPreTrainedModel):
|
|
6 |
def __init__(self, config):
|
7 |
super().__init__(config)
|
8 |
self.config = config
|
9 |
-
self.
|
10 |
-
print(self.encoder)
|
11 |
# Initialize weights and apply final processing
|
12 |
self.post_init()
|
13 |
|
14 |
def forward(self, input_ids, attention_mask, token_type_ids, return_dict):
|
15 |
print(input_ids, attention_mask, token_type_ids)
|
16 |
-
|
17 |
-
return self.encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=return_dict)
|
|
|
6 |
def __init__(self, config):
|
7 |
super().__init__(config)
|
8 |
self.config = config
|
9 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
|
|
|
10 |
# Initialize weights and apply final processing
|
11 |
self.post_init()
|
12 |
|
13 |
def forward(self, input_ids, attention_mask, token_type_ids, return_dict):
|
14 |
print(input_ids, attention_mask, token_type_ids)
|
15 |
+
return self.word_embeddings[input_ids]
|
|