import torch
from torch import nn as nn
from torch.nn import functional as F
from torchmetrics.functional import accuracy

from bitfount import PyTorchBitfountModel, PyTorchClassifierMixIn


class MyCustomModel(PyTorchClassifierMixIn, PyTorchBitfountModel):
    # A simple fully connected classifier implemented as a Bitfount custom model.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Hyperparameters for the model and optimiser.
        self.learning_rate = 0.001

    def create_model(self):
        # Build the underlying network; the input size comes from the datastructure.
        self.input_size = self.datastructure.input_size
        return nn.Sequential(
            nn.Linear(self.input_size, 500),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(500, self.n_classes),
        )

    def forward(self, x):
        # Batches arrive as (tabular features, support columns); only the
        # features are passed through the network.
        x, sup = x
        x = self._model(x.float())
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        preds = F.softmax(preds, dim=1)
        # Note: recent torchmetrics releases require a `task` argument, e.g.
        # accuracy(preds, y, task="multiclass", num_classes=self.n_classes).
        acc = accuracy(preds, y)

        # Log metrics to the progress bar so they are visible during training.
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {
            "val_loss": loss,
            "val_acc": acc,
        }

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        preds = F.softmax(preds, dim=1)

        # Return prediction probabilities alongside the targets for evaluation.
        return {"predictions": preds, "targets": y}

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
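

# --- Usage sketch (an assumption, not part of the original model definition) ---
# The constructor arguments (datastructure, schema, epochs) and the fit() call
# below follow the pattern in the Bitfount custom-model tutorials; exact
# signatures may differ between Bitfount versions, and "data.csv" / "TARGET"
# are hypothetical placeholder names.
if __name__ == "__main__":
    from bitfount import BitfountSchema, CSVSource, DataStructure

    datasource = CSVSource("data.csv")               # hypothetical local dataset
    datastructure = DataStructure(target="TARGET")   # hypothetical target column
    model = MyCustomModel(
        datastructure=datastructure,
        schema=BitfountSchema(),
        epochs=2,
    )
    model.fit(data=datasource)                       # train locally on the CSV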