bristena-op
commited on
Commit
•
c26316d
1
Parent(s):
507e341
try model
Browse files- MyCustomModel.py +64 -0
MyCustomModel.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torchmetrics.functional import accuracy
|
5 |
+
from bitfount import PyTorchBitfountModel, PyTorchClassifierMixIn
|
6 |
+
# Update the class name for your Custom model
|
7 |
+
class MyCustomModel(PyTorchClassifierMixIn, PyTorchBitfountModel):
|
8 |
+
# A custom model built using PyTorch Lightning.
|
9 |
+
def __init__(self, **kwargs):
|
10 |
+
super().__init__(**kwargs)
|
11 |
+
self.learning_rate = 0.001
|
12 |
+
# Initializes the model and sets hyperparameters.
|
13 |
+
# We need to call the parent __init__ first to ensure base model is set up.
|
14 |
+
# Then we can set our custom model parameters.
|
15 |
+
|
16 |
+
def create_model(self):
|
17 |
+
self.input_size = self.datastructure.input_size
|
18 |
+
return nn.Sequential(
|
19 |
+
nn.Linear(self.input_size, 500),
|
20 |
+
nn.ReLU(),
|
21 |
+
nn.Dropout(0.1),
|
22 |
+
nn.Linear(500, self.n_classes),
|
23 |
+
)
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
# Defines the operations we want to use for prediction.
|
27 |
+
x, sup = x
|
28 |
+
x = self._model(x.float())
|
29 |
+
return x
|
30 |
+
|
31 |
+
def training_step(self, batch, batch_idx):
|
32 |
+
# Computes and returns the training loss for a batch of data.
|
33 |
+
x, y = batch
|
34 |
+
y_hat = self(x)
|
35 |
+
loss = F.cross_entropy(y_hat, y)
|
36 |
+
return loss
|
37 |
+
|
38 |
+
def validation_step(self, batch, batch_idx):
|
39 |
+
# Operates on a single batch of data from the validation set.
|
40 |
+
x, y = batch
|
41 |
+
preds = self(x)
|
42 |
+
loss = F.cross_entropy(preds, y)
|
43 |
+
preds = F.softmax(preds, dim=1)
|
44 |
+
acc = accuracy(preds, y)
|
45 |
+
# We can log out some useful stats so we can see progress
|
46 |
+
self.log("val_loss", loss, prog_bar=True)
|
47 |
+
self.log("val_acc", acc, prog_bar=True)
|
48 |
+
return {
|
49 |
+
"val_loss": loss,
|
50 |
+
"val_acc": acc,
|
51 |
+
}
|
52 |
+
|
53 |
+
def test_step(self, batch, batch_idx):
|
54 |
+
x, y = batch
|
55 |
+
preds = self(x)
|
56 |
+
preds = F.softmax(preds, dim=1)
|
57 |
+
|
58 |
+
# Output targets and prediction for later
|
59 |
+
return {"predictions": preds, "targets": y}
|
60 |
+
|
61 |
+
def configure_optimizers(self):
|
62 |
+
# Configure the optimizer we wish to use whilst training.
|
63 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
|
64 |
+
return optimizer
|