import torch
from torch import nn
from transformers import PreTrainedModel, modeling_outputs

from .configuration_word2vec import PretrainedWord2VecHFConfig


class PretrainedWord2VecHFModel(PreTrainedModel):
    """Expose a word2vec-style embedding table through the Hugging Face model API.

    The model is a single ``nn.Embedding`` lookup: token ids go in, and the
    corresponding embedding vectors come out as ``last_hidden_state``.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config):
        """Create an embedding table sized from the config.

        Args:
            config: Configuration object providing ``num_words`` (number of
                rows) and ``vector_size`` (embedding dimension).
        """
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)

    def set_embeddings(self, embeddings):
        """Replace the embedding table with the given pretrained vectors.

        The data is copied, so later changes to ``embeddings`` do not affect
        the model. ``nn.Embedding.from_pretrained`` is called with its
        defaults, which freezes the resulting weights.

        Args:
            embeddings: 2-D matrix of vectors, one row per word. May be a
                ``torch.Tensor`` or anything ``torch.tensor`` accepts
                (e.g. nested lists or a NumPy array).
        """
        # NOTE(review): ``self.config`` is not updated here; if the new matrix
        # has a different shape than (num_words, vector_size), confirm that
        # callers keep the config in sync before saving the model.
        if isinstance(embeddings, torch.Tensor):
            # ``torch.tensor(existing_tensor)`` emits a copy-construct
            # warning; ``detach().clone()`` is the recommended equivalent.
            weights = embeddings.detach().clone()
        else:
            weights = torch.tensor(embeddings)
        self.embeddings = nn.Embedding.from_pretrained(weights)

    def forward(self, input_ids, **kwargs):
        """Look up the embedding vector for every id in ``input_ids``.

        Args:
            input_ids: Integer token ids, either a ``torch.Tensor`` or any
                value ``torch.tensor`` can convert (e.g. list or NumPy array).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            ``modeling_outputs.BaseModelOutput`` whose ``last_hidden_state``
            holds the looked-up embeddings.
        """
        # ``torch.tensor`` is a factory function, not a type, so comparing
        # ``type(x)`` against it never matches; ``torch.Tensor`` is the class.
        if not isinstance(input_ids, torch.Tensor):
            # e.g., list or np.array -- build the tensor on the same device as
            # the embedding weights so the lookup works after ``model.to(...)``.
            input_ids = torch.tensor(
                input_ids, device=self.embeddings.weight.device
            )
        x = self.embeddings(input_ids)
        return modeling_outputs.BaseModelOutput(last_hidden_state=x)