import torch
from torch import nn as nn
from torch.nn import functional as F
from torchmetrics.functional import accuracy

from bitfount import PyTorchBitfountModel, PyTorchClassifierMixIn


# Update the class name for your Custom model
class MyCustomModel(PyTorchClassifierMixIn, PyTorchBitfountModel):
    """A custom classification model built using PyTorch Lightning.

    Combines Bitfount's ``PyTorchClassifierMixIn`` with ``PyTorchBitfountModel``
    and defines a small fully-connected network together with the Lightning
    training / validation / test hooks and the optimizer configuration.
    """

    def __init__(self, **kwargs):
        """Initialize the model and set its hyperparameters.

        The parent ``__init__`` is called first so the base model is set up;
        the custom model parameters are set afterwards.

        Args:
            **kwargs: Forwarded unchanged to the parent class initializers.
        """
        super().__init__(**kwargs)
        # Read by configure_optimizers when building the AdamW optimizer.
        self.learning_rate = 0.001

    def create_model(self):
        """Build and return the underlying network.

        The input width is read from ``self.datastructure.input_size`` (and
        also stored on ``self.input_size``). The network has one hidden layer
        of 500 units with ReLU and dropout. It ends in a linear layer with
        ``self.n_classes`` outputs, so it returns raw logits with no softmax
        applied.

        Returns:
            An ``nn.Sequential`` module.
        """
        self.input_size = self.datastructure.input_size
        return nn.Sequential(
            nn.Linear(self.input_size, 500),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(500, self.n_classes),
        )

    def forward(self, x):
        """Define the operations used for prediction.

        ``x`` is unpacked as a 2-tuple. Only the first element is cast to
        float and passed through ``self._model``; the second is ignored.
        NOTE(review): presumably ``(features, supplementary_data)`` as
        supplied by Bitfount -- confirm against the data loader.

        Returns:
            The network's output logits.
        """
        features, _ = x
        return self._model(features.float())

    def training_step(self, batch, batch_idx):
        """Compute and return the training loss for a batch of data.

        Args:
            batch: A ``(x, y)`` pair; ``x`` is passed to ``forward`` and ``y``
                is the cross-entropy target.
            batch_idx: Index of the batch (unused).

        Returns:
            The cross-entropy loss between the model logits and ``y``.
        """
        x, y = batch
        logits = self(x)
        return F.cross_entropy(logits, y)

    def validation_step(self, batch, batch_idx):
        """Evaluate a single batch of data from the validation set.

        Computes the cross-entropy loss on the logits and the accuracy on the
        softmax probabilities. Both are logged to the progress bar.

        Args:
            batch: A ``(x, y)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            A dict with keys ``"val_loss"`` and ``"val_acc"``.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        probs = F.softmax(logits, dim=1)
        # NOTE(review): torchmetrics >= 0.11 made ``task`` a required argument
        # of ``accuracy`` (e.g. task="multiclass", num_classes=self.n_classes).
        # This call form relies on an older torchmetrics; confirm the pinned
        # version before upgrading.
        acc = accuracy(probs, y)
        # Log useful stats so progress is visible during training.
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {
            "val_loss": loss,
            "val_acc": acc,
        }

    def test_step(self, batch, batch_idx):
        """Produce softmax predictions for a single test batch.

        Args:
            batch: A ``(x, y)`` pair.
            batch_idx: Index of the batch (unused).

        Returns:
            A dict with the softmax ``"predictions"`` and the ``"targets"``,
            kept for later evaluation.
        """
        x, y = batch
        probs = F.softmax(self(x), dim=1)
        return {"predictions": probs, "targets": y}

    def configure_optimizers(self):
        """Configure the optimizer used while training.

        Returns:
            An AdamW optimizer over all model parameters with learning rate
            ``self.learning_rate``.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)