Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding: utf-8 | |
from torch import nn, optim | |
import pytorch_lightning as pl | |
class LitTrainer(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with a cross-entropy loss.

    Each batch is unpacked as an ``(inputs, targets)`` pair. The wrapped
    model's output is reshaped to ``(batch_size, -1)`` logits before the loss
    is computed, where ``batch_size`` is the leading dimension of the targets.

    Args:
        model: The module to train. Its output for a batch is treated as class
            logits (assumed; not checked here).
        lr: Learning rate for the Adam optimizer. Defaults to ``1e-4``.
    """

    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss = nn.CrossEntropyLoss()
        # The optimizer is created in ``configure_optimizers``; this attribute
        # holds the most recently created one for callers that inspect it.
        self.optim = None

    def _shared_step(self, batch):
        """Return the cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        # Use the targets' leading dimension (rather than a fixed 1) so each
        # sample's logits stay in their own row for any batch size; a batch of
        # one still reshapes to (1, -1).
        batch_size = y.shape[0] if y.dim() > 0 else 1
        y_pred = self(x).reshape(batch_size, -1)
        return self.loss(y_pred, y)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        train_loss = self._shared_step(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        validate_loss = self._shared_step(batch)
        self.log("val_loss", validate_loss)

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it as ``test_loss``."""
        test_loss = self._shared_step(batch)
        self.log("test_loss", test_loss)

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Create and return an Adam optimizer over this module's parameters.

        The optimizer is built here rather than in ``__init__`` so that it is
        created when Lightning invokes this hook during trainer setup. It
        therefore references the parameters actually being trained (e.g. after
        a strategy has wrapped the model), and each call yields a fresh
        optimizer.
        """
        self.optim = optim.Adam(self.parameters(), lr=self.lr)
        return self.optim