glove-wiki-gigaword-50 / modeling_word2vec.py
Iseratho's picture
add model
47ef5ed
from transformers import PreTrainedModel, modeling_outputs
from torch import nn
import torch
from .configuration_word2vec import PretrainedWord2VecHFConfig
class PretrainedWord2VecHFModel(PreTrainedModel):
    """Static word-embedding lookup exposed as a Hugging Face ``PreTrainedModel``.

    The model holds a single ``nn.Embedding`` table. ``forward`` returns the
    row of that table for each given token id. Nothing is contextual: a given
    id always maps to the same vector.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config):
        """Build an embedding table sized from ``config``.

        Args:
            config: A ``PretrainedWord2VecHFConfig``. Its ``num_words`` and
                ``vector_size`` attributes give the number of rows and the
                vector length of the table. The table keeps ``nn.Embedding``'s
                default initialisation until weights are loaded or
                ``set_embeddings`` is called.
        """
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)

    def set_embeddings(self, embeddings):
        """Replace the embedding table with a frozen copy of ``embeddings``.

        Args:
            embeddings: 2-D array-like (tensor, NumPy array or nested list)
                of shape ``(num_words, vector_size)``. The data is copied, so
                later changes to the caller's object do not affect the model.
                The dtype is kept as-is.

        The new table is frozen (``requires_grad=False``), which is the
        default of ``nn.Embedding.from_pretrained``. ``config.num_words`` and
        ``config.vector_size`` are updated to the table's shape.
        """
        # ``torch.tensor(t)`` on an existing tensor warns and recommends
        # ``t.clone().detach()``. Both branches yield an independent copy.
        if isinstance(embeddings, torch.Tensor):
            weights = embeddings.detach().clone()
        else:
            weights = torch.tensor(embeddings)
        self.embeddings = nn.Embedding.from_pretrained(weights)
        # Keep the config in sync with the real weights. Otherwise a model
        # saved after this call would be rebuilt at the old size on reload,
        # and the saved weights would not fit.
        self.config.num_words, self.config.vector_size = weights.shape

    def forward(self, input_ids, **kwargs):
        """Look up the embedding vector for every id in ``input_ids``.

        Args:
            input_ids: Integer token ids as a tensor, NumPy array or (nested)
                list. Non-tensor inputs are converted to a tensor on the same
                device as the embedding table. Tensor inputs are used as given.
            **kwargs: Accepted so Hugging Face-style calls (e.g. with an
                ``attention_mask``) do not fail. They are ignored.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has shape
            ``(*input_ids.shape, embedding_dim)``.
        """
        # Compare against the ``torch.Tensor`` class. ``torch.tensor`` is a
        # factory function, so a type comparison with it is never equal.
        if not isinstance(input_ids, torch.Tensor):  # e.g., list or np.array
            input_ids = torch.as_tensor(
                input_ids, device=self.embeddings.weight.device
            )
        x = self.embeddings(input_ids)
        return modeling_outputs.BaseModelOutput(last_hidden_state=x)