# Spaces: Runtime error
from transformers import DistilBertForSequenceClassification | |
from torch import nn | |
class DialogueManagerModel(nn.Module):
    """DistilBERT sequence classifier with a frozen body and a trainable head.

    The wrapped Hugging Face model is loaded from ``DEFAULT_MODEL``, every
    parameter is frozen, and then only the final ``classifier`` layer is
    re-enabled for gradient updates.
    """

    DEFAULT_MODEL = "distilbert-base-uncased"

    def __init__(self, n_classes, model_name=None, device='cpu'):
        """Load the pretrained model and prepare it for head-only training.

        Args:
            n_classes: Number of output classes; must be at least 1.
            model_name: Must be ``None``; only ``DEFAULT_MODEL`` is supported.
            device: Device the wrapped model is moved to.

        Raises:
            NotImplementedError: If ``model_name`` is not ``None``.
            ValueError: If ``n_classes`` is smaller than 1.
        """
        super().__init__()
        # Validate arguments before the (expensive) checkpoint load.
        if model_name is not None:
            raise NotImplementedError(
                "custom checkpoints are not supported yet; leave model_name=None"
            )
        if n_classes < 1:
            raise ValueError(
                f"n_classes must be a positive integer, got {n_classes!r}"
            )

        self.n_classes = n_classes
        # Passing num_labels sizes the classification head at load time and
        # keeps the wrapped model's config in agreement with the head width,
        # instead of swapping in a new layer behind the config's back.
        self.model = DistilBertForSequenceClassification.from_pretrained(
            self.DEFAULT_MODEL, num_labels=n_classes
        )
        self.model.to(device)

        self.freeze_layers()
        # Only the final classification layer is trained.
        # NOTE(review): the base checkpoint presumably ships without a
        # sequence-classification head, so `pre_classifier` would also be
        # freshly initialised yet stays frozen here - confirm that is intended.
        self.model.classifier.requires_grad_(True)

    def freeze_layers(self):
        """Disable gradient tracking for every parameter of the wrapped model."""
        self.model.requires_grad_(False)

    def forward(self, X, attention_mask=None):
        """Run the wrapped model and return its output unchanged.

        Args:
            X: Passed to the wrapped model as its first positional argument.
            attention_mask: Optional mask passed through as the
                ``attention_mask`` keyword. Supply it for padded batches so
                padding is not treated as real input; ``None`` keeps the
                previous single-argument behaviour.
        """
        return self.model(X, attention_mask=attention_mask)