File size: 914 Bytes
d4efd8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import PreTrainedModel, PretrainedConfig
from torch import nn
import torch

class PretrainedWord2VecHFConfig(PretrainedConfig):
    """Configuration for a static word-embedding lookup model.

    Stores the size of the embedding table (``num_words`` rows of
    ``vector_size`` columns) and mirrors the vector width as
    ``hidden_size``.
    """

    model_type = "glove"

    def __init__(self, num_words=400001, vector_size=50, **kwargs):
        """Record the table dimensions, then defer to ``PretrainedConfig``.

        Args:
            num_words: Number of rows in the embedding table.
            vector_size: Dimensionality of each word vector.
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        self.num_words = num_words
        self.vector_size = vector_size
        # Required for sBERT: expose the vector width as `hidden_size`.
        self.hidden_size = vector_size
        super().__init__(**kwargs)

class PretrainedWord2VecHFModel(PreTrainedModel):
    """Hugging Face model wrapper around a single ``nn.Embedding`` table.

    The model maps token ids to their word vectors; there are no other
    layers. The table is created with the dimensions from the config and
    can be replaced with pre-trained vectors via :meth:`set_embeddings`.
    """

    config_class = PretrainedWord2VecHFConfig

    def __init__(self, config):
        """Build an embedding table of ``num_words`` x ``vector_size``."""
        super().__init__(config)
        self.embeddings = nn.Embedding(config.num_words, config.vector_size)

    def set_embeddings(self, embeddings):
        """Replace the embedding table with pre-trained vectors.

        Args:
            embeddings: A 2-D array-like or tensor of shape
                ``(num_words, vector_size)``. The data is copied; its dtype
                is preserved.

        The new layer is created with ``nn.Embedding.from_pretrained`` and
        its default ``freeze=True``, so the weights are not trainable.
        """
        if isinstance(embeddings, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct UserWarning;
            # detach().clone() is the documented equivalent.
            weights = embeddings.detach().clone()
        else:
            weights = torch.tensor(embeddings)
        self.embeddings = nn.Embedding.from_pretrained(weights)

        # Keep the config in sync with the new table so that a later
        # save/reload builds a layer of the matching shape.
        num_words, vector_size = weights.shape
        self.config.num_words = num_words
        self.config.vector_size = vector_size
        self.config.hidden_size = vector_size

    def forward(self, input_ids, **kwargs):
        """Look up the word vectors for ``input_ids``.

        Args:
            input_ids: Token ids as a tensor or a (nested) list of ints.
            **kwargs: Accepted and ignored.

        Returns:
            A tensor with the shape of ``input_ids`` plus a trailing
            dimension of the embedding width.
        """
        # as_tensor avoids the copy (and warning) when a tensor is passed,
        # and places list input on the same device as the weights.
        ids = torch.as_tensor(input_ids, device=self.embeddings.weight.device)
        return self.embeddings(ids)