bristena-op committed
Commit c26316d
1 Parent(s): 507e341
Files changed (1)
  1. MyCustomModel.py +64 -0
MyCustomModel.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchmetrics.functional import accuracy
+ from bitfount import PyTorchBitfountModel, PyTorchClassifierMixIn
+
+ # Update the class name for your custom model.
+ class MyCustomModel(PyTorchClassifierMixIn, PyTorchBitfountModel):
+     # A custom model built using PyTorch Lightning.
+     def __init__(self, **kwargs):
+         # Initializes the model and sets hyperparameters.
+         # We need to call the parent __init__ first to ensure the base model
+         # is set up; then we can set our custom model parameters.
+         super().__init__(**kwargs)
+         self.learning_rate = 0.001
+
+     def create_model(self):
+         # Creates and returns the underlying PyTorch model: a small
+         # fully connected network sized from the data structure.
+         self.input_size = self.datastructure.input_size
+         return nn.Sequential(
+             nn.Linear(self.input_size, 500),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(500, self.n_classes),
+         )
+
+     def forward(self, x):
+         # Defines the operations we want to use for prediction.
+         # Each batch arrives as a (data, supplementary-data) tuple;
+         # only the data itself is needed here.
+         x, sup = x
+         x = self._model(x.float())
+         return x
+
+     def training_step(self, batch, batch_idx):
+         # Computes and returns the training loss for a batch of data.
+         x, y = batch
+         y_hat = self(x)
+         loss = F.cross_entropy(y_hat, y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         # Operates on a single batch of data from the validation set.
+         x, y = batch
+         preds = self(x)
+         loss = F.cross_entropy(preds, y)
+         preds = F.softmax(preds, dim=1)
+         # Note: torchmetrics >= 0.11 requires explicit task arguments, e.g.
+         # accuracy(preds, y, task="multiclass", num_classes=self.n_classes).
+         acc = accuracy(preds, y)
+         # Log some useful stats so we can see progress.
+         self.log("val_loss", loss, prog_bar=True)
+         self.log("val_acc", acc, prog_bar=True)
+         return {
+             "val_loss": loss,
+             "val_acc": acc,
+         }
+
+     def test_step(self, batch, batch_idx):
+         # Runs prediction on a single batch of data from the test set.
+         x, y = batch
+         preds = self(x)
+         preds = F.softmax(preds, dim=1)
+
+         # Output targets and predictions for later evaluation.
+         return {"predictions": preds, "targets": y}
+
+     def configure_optimizers(self):
+         # Configure the optimizer we wish to use whilst training.
+         optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
+         return optimizer
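For reference, below is a minimal sketch of how this model might be trained locally, assuming Bitfount's tutorial-style API (CSVSource, BitfountSchema, DataStructure, and the model's fit method). The CSV path, table name, and "TARGET" column are hypothetical placeholders, not part of this commit.

from bitfount import BitfountSchema, CSVSource, DataStructure

from MyCustomModel import MyCustomModel

# Hypothetical dataset: a CSV with feature columns and a "TARGET" label column.
datasource = CSVSource(path="training_data.csv")
schema = BitfountSchema(datasource, table_name="training_data")
datastructure = DataStructure(target="TARGET", table="training_data")

# Instantiate the custom model; epochs/batch_size are handled by the base class.
model = MyCustomModel(
    datastructure=datastructure,
    schema=schema,
    epochs=2,
    batch_size=64,
)

# Train locally on the datasource and inspect the results.
results = model.fit(data=datasource)
print(results)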