import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.functional import accuracy
from bitfount import PyTorchBitfountModel, PyTorchClassifierMixIn


# Update the class name for your custom model.
class MyCustomModel(PyTorchClassifierMixIn, PyTorchBitfountModel):
    # A custom model built using PyTorch Lightning.
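    # Implements the hooks Bitfount expects a custom model to provide:
    # create_model, forward, training_step, validation_step, test_step,
    # and configure_optimizers.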
    def __init__(self, **kwargs):
        # Initializes the model and sets hyperparameters.
        # We need to call the parent __init__ first to ensure the base model
        # is set up, then we can set our custom model parameters.
        super().__init__(**kwargs)
        self.learning_rate = 0.001

    def create_model(self):
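        # Builds the underlying network. `input_size` and `n_classes` are
        # supplied by the Bitfount datastructure and the classifier mixin.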
        self.input_size = self.datastructure.input_size
        return nn.Sequential(
            nn.Linear(self.input_size, 500),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(500, self.n_classes),
        )

    def forward(self, x):
        # Defines the operations we want to use for prediction.
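        # Bitfount batches arrive as (tensor data, supplementary columns);
        # only the tensor portion is needed for the forward pass.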
        x, sup = x
        x = self._model(x.float())
        return x

    def training_step(self, batch, batch_idx):
        # Computes and returns the training loss for a batch of data.
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # Operates on a single batch of data from the validation set.
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        preds = F.softmax(preds, dim=1)
        # torchmetrics >= 0.11 requires an explicit task; on older versions
        # `accuracy(preds, y)` works without the extra arguments.
        acc = accuracy(preds, y, task="multiclass", num_classes=self.n_classes)
        # Log some useful stats so we can monitor progress.
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {
            "val_loss": loss,
            "val_acc": acc,
        }

    def test_step(self, batch, batch_idx):
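        # Operates on a single batch of data from the test set and returns
        # the soft-maxed predictions alongside the targets.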
        x, y = batch
        preds = self(x)
        preds = F.softmax(preds, dim=1)

        # Return targets and predictions so they can be collected for
        # evaluation later.
        return {"predictions": preds, "targets": y}

    def configure_optimizers(self):
        # Configure the optimizer we wish to use whilst training.
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer
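

# --- Usage sketch (illustrative only) ---
# A minimal, hypothetical example of training this model locally with
# Bitfount. The datasource path, table name, target column, and epoch
# count below are assumptions, not part of this file; adapt them to your
# own data. Kept commented out so the module stays import-safe.
#
# from bitfount import BitfountSchema, CSVSource, DataStructure
#
# datasource = CSVSource("train.csv")  # assumed CSV file
# datastructure = DataStructure(table="train", target="TARGET")  # assumed names
# model = MyCustomModel(
#     datastructure=datastructure,
#     schema=BitfountSchema(),
#     epochs=2,  # assumed: small run for illustration
# )
# model.fit(data=datasource)  # assumed local-training entry point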