lenet_mnist / src /trainer.py
rzimmerdev's picture
Fixed devices and demo
a0f925f
raw
history blame
No virus
1.14 kB
#!/usr/bin/env python
# coding: utf-8
from torch import nn, optim
import pytorch_lightning as pl
class LitTrainer(pl.LightningModule):
    """Lightning wrapper that trains a classifier ``model`` with cross-entropy.

    The wrapped ``model`` does the actual computation; this module supplies
    the loss, the optimizer, and the train/validation/test step hooks that
    Lightning calls.

    Args:
        model: the ``nn.Module`` to train. Its output for a batch is reshaped
            to ``(batch_size, -1)`` logits before the loss is computed.
        lr: learning rate for the Adam optimizer. Defaults to ``1e-4``, the
            value that was previously hard-coded.
    """

    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        # Built eagerly (instead of inside configure_optimizers) so the public
        # ``self.optim`` attribute keeps existing; the model's parameters are
        # already registered by the ``self.model`` assignment above.
        self.optim = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss()

    def _compute_loss(self, batch):
        """Return the cross-entropy loss of ``self.model`` on one ``(x, y)`` batch."""
        x, y = batch
        # The previous ``reshape(1, -1)`` squashed every sample of the batch
        # into a single row, so any batch size > 1 made CrossEntropyLoss fail
        # with a batch-size mismatch. Reshape to one row per target instead;
        # for a single-sample batch (or a 0-dim target) this is exactly the
        # old (1, -1) shape.
        # NOTE(review): assumes the first dim of ``y`` is the batch dim
        # (class indices (B,) or one-hot (B, C)) — confirm against the loader.
        batch_size = y.shape[0] if y.dim() > 0 else 1
        y_pred = self.model(x).reshape(batch_size, -1)
        return self.loss(y_pred, y)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        train_loss = self._compute_loss(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        validate_loss = self._compute_loss(batch)
        self.log("val_loss", validate_loss)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        test_loss = self._compute_loss(batch)
        self.log("test_loss", test_loss)

    def forward(self, x):
        """Run the wrapped model on ``x`` (no reshaping)."""
        return self.model(x)

    def configure_optimizers(self):
        """Return the Adam optimizer created in ``__init__``."""
        return self.optim