File size: 1,144 Bytes
b358118
 
2262103
b358118
 
 
 
2262103
b358118
 
2262103
 
b358118
 
 
 
 
2262103
b358118
 
 
 
 
 
 
 
 
2262103
b358118
 
 
 
 
 
 
 
2262103
b358118
 
 
2262103
 
 
b358118
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/env python
# coding: utf-8
from torch import nn, optim
import pytorch_lightning as pl


class LitTrainer(pl.LightningModule):
    """Lightning wrapper that trains an arbitrary classifier with Adam + cross-entropy.

    The wrapped ``model`` is called directly on the input ``x`` of each batch;
    its output is reshaped to ``(1, -1)`` before the loss is computed.

    Args:
        model: any callable ``nn.Module`` mapping inputs ``x`` to class scores.
        lr: Adam learning rate (default ``1e-4``, matching the previous
            hard-coded value; new parameter is backward-compatible).
    """

    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        # Optimizer built here (not in configure_optimizers) so callers that
        # already read ``self.optim`` keep working.
        self.optim = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss()

    def _shared_loss(self, batch):
        # Single implementation of the loss used by train/val/test steps
        # (previously duplicated three times).
        x, y = batch
        # NOTE(review): reshape(1, -1) flattens the model output into a single
        # row, i.e. an effective batch size of 1 — presumably the dataloader
        # yields one sample per batch; confirm against the data pipeline.
        y_pred = self.model(x).reshape(1, -1)
        return self.loss(y_pred, y)

    def training_step(self, batch, batch_idx):
        # Returns the loss so Lightning can backpropagate it.
        train_loss = self._shared_loss(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop; nothing returned (metric is logged only)
        validate_loss = self._shared_loss(batch)
        self.log("val_loss", validate_loss)

    def test_step(self, batch, batch_idx):
        # this is the test loop; nothing returned (metric is logged only)
        test_loss = self._shared_loss(batch)
        self.log("test_loss", test_loss)

    def forward(self, x):
        # Inference path: delegate straight to the wrapped model.
        return self.model(x)

    def configure_optimizers(self):
        # Hand Lightning the optimizer constructed in __init__.
        return self.optim