File size: 2,273 Bytes
c26316d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import torch
from torch import nn as nn
from torch.nn import functional as F
from torchmetrics.functional import accuracy
from bitfount import PyTorchBitfountModel, PyTorchClassifierMixIn
# Rename this class to suit your custom model.
class MyCustomModel(PyTorchClassifierMixIn, PyTorchBitfountModel):
    """A custom classifier built with PyTorch Lightning on top of Bitfount.

    The network is a one-hidden-layer MLP (500 units, ReLU, 10% dropout)
    mapping ``datastructure.input_size`` input features to ``n_classes``
    output scores.
    """

    def __init__(self, **kwargs):
        """Initialise the model and set hyperparameters.

        The parent ``__init__`` is called first so the base model is fully set
        up before the custom hyperparameters are assigned.

        Args:
            **kwargs: Forwarded unchanged to the Bitfount base classes.
        """
        super().__init__(**kwargs)
        self.learning_rate = 0.001

    def create_model(self):
        """Build and return the underlying ``nn.Sequential`` network.

        Side effect: stores the datastructure's input size on
        ``self.input_size``.
        """
        self.input_size = self.datastructure.input_size
        return nn.Sequential(
            nn.Linear(self.input_size, 500),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(500, self.n_classes),
        )

    def forward(self, x):
        """Return unnormalised class scores (logits) for a batch.

        Args:
            x: A 2-tuple whose first element is the feature tensor. The second
                element is ignored (presumably Bitfount's support columns;
                confirm against the Bitfount data loader).

        Returns:
            The output of ``self._model`` applied to the float-cast features.
        """
        features, _support = x
        # NOTE(review): assumes the base class stores create_model()'s result
        # as ``self._model``; confirm against PyTorchBitfountModel.
        return self._model(features.float())

    def training_step(self, batch, batch_idx):
        """Compute and return the cross-entropy training loss for a batch."""
        x, y = batch
        logits = self(x)
        return F.cross_entropy(logits, y)

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and accuracy for a batch.

        Returns:
            A dict with ``"val_loss"`` and ``"val_acc"`` tensors.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Top-1 multiclass accuracy computed directly with torch.
        # torchmetrics.functional.accuracy is avoided because, from
        # torchmetrics 0.11 on, it requires a ``task`` argument; calling it
        # as accuracy(preds, y) raises TypeError. argmax is invariant under
        # softmax, so the raw logits can be compared directly.
        acc = (logits.argmax(dim=1) == y).float().mean()
        # Log useful stats so progress is visible in the progress bar.
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {
            "val_loss": loss,
            "val_acc": acc,
        }

    def test_step(self, batch, batch_idx):
        """Return softmax probabilities and targets for later evaluation.

        Returns:
            A dict with ``"predictions"`` (class probabilities along dim 1)
            and ``"targets"``.
        """
        x, y = batch
        probabilities = F.softmax(self(x), dim=1)
        return {"predictions": probabilities, "targets": y}

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters at ``self.learning_rate``."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)