File size: 809 Bytes
65423b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e4e0da
 
0fc1f17
3e4e0da
 
 
 
0fc1f17
3e4e0da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch.nn as nn
import torch.nn.functional as F

class TextClassifierModel(nn.Module):
    """Bag-of-embeddings text classifier.

    The network is: ``nn.EmbeddingBag`` -> ``nn.BatchNorm1d`` -> ReLU ->
    ``nn.Linear``.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_size: Width of each embedding vector; also the feature width
            seen by the batch-norm and linear layers.
        num_class: Number of output features of the final linear layer.
    """

    def __init__(self, vocab_size, embed_size, num_class):
        super().__init__()

        # Attribute names and creation order determine the state-dict keys,
        # so they must stay stable for saved checkpoints to load.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_size)
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.fc = nn.Linear(embed_size, num_class)

    def forward(self, text, offsets):
        """Return the linear layer's output for each bag.

        Args:
            text: Passed straight to ``nn.EmbeddingBag`` as its input
                (presumably a flat tensor of token indices -- confirm
                against the caller).
            offsets: Passed straight to ``nn.EmbeddingBag`` as the bag
                start positions within ``text``.

        Returns:
            The output of ``self.fc`` applied to the normalized, ReLU-activated
            bag embeddings.
        """
        pooled = self.embedding(text, offsets)
        hidden = F.relu(self.bn1(pooled))
        return self.fc(hidden)


def load_state_dict(new_model, trained_model, vocab):
    """Load trained weights from a checkpoint mapping into ``new_model``.

    Args:
        new_model: Object exposing ``load_state_dict``; it is updated in
            place.
        trained_model: Mapping holding the weights under the
            ``'model_state_dict'`` key.
        vocab: Unused. Kept only so existing callers that pass it keep
            working; the model's dimensions are fixed when ``new_model`` is
            constructed, not here.

    Returns:
        The same ``new_model`` object, after its weights have been loaded.

    Raises:
        KeyError: If ``trained_model`` has no ``'model_state_dict'`` entry.
    """
    # The checkpoint wraps the weights in a dict; only that entry is needed.
    new_model.load_state_dict(trained_model['model_state_dict'])
    return new_model