from torch import nn
from transformers import DistilBertForSequenceClassification


class DialogueManagerModel(nn.Module):
    """Dialogue-manager classifier: a frozen DistilBERT encoder with a trainable classification head."""

    DEFAULT_MODEL = "distilbert-base-uncased"

    def __init__(self, n_classes, model_name=None, device='cpu'):
        super().__init__()
        if model_name is None:
            self.model = DistilBertForSequenceClassification.from_pretrained(self.DEFAULT_MODEL)
        else:
            # Only the default DistilBERT checkpoint is supported for now.
            raise NotImplementedError()
        self.model.to(device)
        self.n_classes = n_classes

        # Freeze the pretrained encoder, then swap in a fresh classification head
        # sized for our label set; only this head is left trainable.
        self.freeze_layers()
        self.model.classifier = nn.Linear(self.model.classifier.in_features, self.n_classes,
                                          device=device)
        for param in self.model.classifier.parameters():
            param.requires_grad = True

    def freeze_layers(self):
        """Disable gradients for every parameter of the wrapped model."""
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, X, attention_mask=None):
        # X is a batch of token ids; the optional attention mask keeps padding
        # tokens from being attended to. Returns a SequenceClassifierOutput
        # whose `.logits` holds the per-class scores.
        return self.model(X, attention_mask=attention_mask)
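

# Illustrative usage sketch (not part of the original module): the tokenizer
# checkpoint, class count, and example utterance below are assumptions for
# demonstration only.
if __name__ == "__main__":
    import torch
    from transformers import DistilBertTokenizerFast

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    model = DialogueManagerModel(n_classes=5)  # hypothetical number of dialogue classes
    model.eval()

    # Tokenize a sample utterance and run a forward pass without gradients.
    batch = tokenizer(["hello, can you book a table for two?"],
                      padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(batch["input_ids"], attention_mask=batch["attention_mask"])

    predicted_class = output.logits.argmax(dim=-1)
    print(predicted_class)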