Iseratho commited on
Commit
47ef5ed
1 Parent(s): 9abce19
Files changed (1) hide show
  1. modeling_word2vec.py +5 -3
modeling_word2vec.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import PreTrainedModel
2
  from torch import nn
3
  import torch
4
  from .configuration_word2vec import PretrainedWord2VecHFConfig
@@ -14,5 +14,7 @@ class PretrainedWord2VecHFModel(PreTrainedModel):
14
  self.embeddings = nn.Embedding.from_pretrained(torch.tensor(embeddings))
15
 
16
  def forward(self, input_ids, **kwargs):
17
- x = self.embeddings(torch.tensor(input_ids))
18
- return x
 
 
1
+ from transformers import PreTrainedModel, modeling_outputs
2
  from torch import nn
3
  import torch
4
  from .configuration_word2vec import PretrainedWord2VecHFConfig
14
  self.embeddings = nn.Embedding.from_pretrained(torch.tensor(embeddings))
15
 
16
  def forward(self, input_ids, **kwargs):
17
+ if type(input_ids) != torch.tensor: # e.g., list or np.array
18
+ input_ids = torch.tensor(input_ids)
19
+ x = self.embeddings(input_ids)
20
+ return modeling_outputs.BaseModelOutput(last_hidden_state=x)