from transformers import PreTrainedModel, PretrainedConfig
from torch import nn
import torch


class PretrainedWord2VecHFConfig(PretrainedConfig):
    """Configuration for a static word-embedding lookup model.

    Args:
        num_words: Number of rows in the embedding table.
        vector_size: Dimensionality of each embedding vector.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "glove"

    def __init__(self, num_words: int = 400001, vector_size: int = 50, **kwargs):
        self.num_words = num_words
        self.vector_size = vector_size
        # Mirrors vector_size under the standard name. Required for sBERT.
        self.hidden_size = self.vector_size
        super().__init__(**kwargs)


class PretrainedWord2VecHFModel(PreTrainedModel):
    """Hugging Face model wrapper around a single ``nn.Embedding`` table.

    There are no other layers: ``forward`` maps each token id directly to
    its row in the embedding table.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config: PretrainedWord2VecHFConfig):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)

    def set_embeddings(self, embeddings, freeze: bool = True) -> None:
        """Replace the embedding table with pre-computed vectors.

        Args:
            embeddings: 2-D array-like or tensor of shape
                ``(num_words, vector_size)``. The data is copied.
            freeze: If True (the default, matching the previous behaviour),
                the loaded weights do not receive gradient updates.

        The config's ``num_words``, ``vector_size`` and ``hidden_size`` are
        updated to the new table's shape. This lets a ``save_pretrained`` /
        ``from_pretrained`` round trip rebuild an embedding of the right size.
        """
        if isinstance(embeddings, torch.Tensor):
            # torch.tensor() on an existing tensor warns; copy explicitly instead.
            weights = embeddings.detach().clone()
        else:
            weights = torch.tensor(embeddings)
        # from_pretrained rejects non-2-D input, so the unpack below is safe.
        self.embeddings = nn.Embedding.from_pretrained(weights, freeze=freeze)
        num_words, vector_size = self.embeddings.weight.shape
        self.config.num_words = num_words
        self.config.vector_size = vector_size
        self.config.hidden_size = vector_size

    def forward(self, input_ids, **kwargs):
        """Look up the embedding vector for every id in ``input_ids``.

        Args:
            input_ids: Integer token ids, as a tensor or a (nested) list.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(*input_ids.shape, vector_size)``.
        """
        # as_tensor avoids the copy and UserWarning that torch.tensor() causes
        # on tensor inputs. It also places the ids on the weights' device, so
        # lookups work after the model is moved (e.g. model.to("cuda")).
        ids = torch.as_tensor(input_ids, device=self.embeddings.weight.device)
        return self.embeddings(ids)