| import torch |
| from torch import nn |
|
|
class Finetune(nn.Module):
    """Plain fine-tuning learner: a backbone followed by a linear classifier.

    Every parameter of both the backbone and the classifier is handed to the
    optimizer (see ``get_parameters``), and the task hooks are no-ops.

    Args:
        backbone: Callable module whose output is indexable with the key
            ``'features'``; that entry is fed to the linear classifier.
        feat_dim: Input width of the linear classifier.
        num_class: Output width of the linear classifier.
        **kwargs: Extra options. ``kwargs['device']`` is required (a
            ``KeyError`` is raised when it is missing); the full mapping is
            kept on ``self.kwargs``.
    """

    def __init__(self, backbone, feat_dim, num_class, **kwargs):
        super().__init__()
        self.backbone = backbone
        self.feat_dim = feat_dim
        self.num_class = num_class
        self.classifier = nn.Linear(feat_dim, num_class)
        self.loss_fn = nn.CrossEntropyLoss(reduction='mean')
        # Only the inputs are moved to this device; the module itself is not
        # moved here, so the caller is presumably responsible for that.
        self.device = kwargs['device']
        self.kwargs = kwargs

    def _logits(self, x):
        """Run the backbone and classify its ``'features'`` output."""
        return self.classifier(self.backbone(x)['features'])

    def _to_device(self, data):
        """Return ``data['image']`` and ``data['label']`` moved to ``self.device``."""
        return data['image'].to(self.device), data['label'].to(self.device)

    def observe(self, data):
        """Run one training forward pass on a batch.

        No backward pass or optimizer step happens here; the loss is returned
        with its autograd graph attached so the caller can do that.

        Args:
            data: Mapping with ``'image'`` and ``'label'`` tensors. The first
                dimension of the image tensor is treated as the batch size.

        Returns:
            Tuple ``(pred, acc, loss)``: the argmax over dim 1 of the logits,
            the fraction of predictions equal to the labels, and the mean
            cross-entropy loss tensor.

        Raises:
            ZeroDivisionError: If the batch is empty.
        """
        x, y = self._to_device(data)
        logit = self._logits(x)
        loss = self.loss_fn(logit, y)
        pred = torch.argmax(logit, dim=1)
        correct = torch.sum(pred == y).item()
        return pred, correct / x.size(0), loss

    def inference(self, data):
        """Predict labels for a batch and measure accuracy.

        Args:
            data: Mapping with ``'image'`` and ``'label'`` tensors. The first
                dimension of the image tensor is treated as the batch size.

        Returns:
            Tuple ``(pred, acc)``: the argmax over dim 1 of the logits and the
            fraction of predictions equal to the labels.

        Raises:
            ZeroDivisionError: If the batch is empty.
        """
        x, y = self._to_device(data)
        # Nothing returned from here is differentiable, so skip building the
        # autograd graph and keeping activations alive.
        with torch.no_grad():
            pred = torch.argmax(self._logits(x), dim=1)
        correct = torch.sum(pred == y).item()
        return pred, correct / x.size(0)

    def forward(self, x):
        """Return classifier logits for ``x`` (no device transfer is done)."""
        return self._logits(x)

    def before_task(self, task_idx, buffer, train_loader, test_loaders):
        """Hook called before a task; plain fine-tuning does nothing."""

    def after_task(self, task_idx, buffer, train_loader, test_loaders):
        """Hook called after a task; plain fine-tuning does nothing."""

    def get_parameters(self, config):
        """Return optimizer parameter groups for the backbone and classifier.

        Args:
            config: Unused by this learner.

        Returns:
            A list of two ``{"params": ...}`` dicts, backbone first, then
            classifier. Each value is the iterator returned by
            ``Module.parameters()``, so it can be consumed only once.
        """
        return [
            {"params": self.backbone.parameters()},
            {"params": self.classifier.parameters()},
        ]
|
|
|
|