Spaces:
Sleeping
Sleeping
File size: 1,144 Bytes
b358118 2262103 b358118 2262103 b358118 2262103 b358118 2262103 b358118 2262103 b358118 2262103 b358118 2262103 b358118 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
#!/usr/bin/env python
# coding: utf-8
from torch import nn, optim
import pytorch_lightning as pl
class LitTrainer(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with cross-entropy loss.

    Each batch is unpacked as ``(x, y)``. ``model(x)`` is flattened to
    ``(batch_size, -1)`` logits and scored against ``y`` with
    ``nn.CrossEntropyLoss``. Parameters are optimized with Adam.

    Args:
        model: Module to train; its parameters are the ones optimized.
        lr: Adam learning rate. Defaults to ``1e-4``, the value that was
            previously hard-coded.
    """

    def __init__(self, model: nn.Module, *, lr: float = 1e-4):
        super().__init__()
        self.model = model
        # Only the learning rate is stored here; the optimizer itself is built
        # in configure_optimizers, the hook Lightning uses to obtain it.
        self.lr = lr
        self.loss = nn.CrossEntropyLoss()

    def _shared_step(self, batch):
        """Return the cross-entropy loss of ``model`` on one ``(x, y)`` batch."""
        x, y = batch
        # Keep the leading (batch) dimension and flatten the rest, giving the
        # (N, C) logits CrossEntropyLoss expects. The former reshape(1, -1)
        # merged all samples into a single row and so only worked for a batch
        # of one; for a batch of one the result here is unchanged.
        batch_size = y.shape[0] if y.ndim > 0 else 1
        y_pred = self.model(x).reshape(batch_size, -1)
        return self.loss(y_pred, y)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it as ``train_loss`` and return it."""
        train_loss = self._shared_step(batch)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log it as ``val_loss``."""
        validate_loss = self._shared_step(batch)
        self.log("val_loss", validate_loss)

    def test_step(self, batch, batch_idx):
        """Compute the test loss and log it as ``test_loss``."""
        test_loss = self._shared_step(batch)
        self.log("test_loss", test_loss)

    def forward(self, x):
        """Return the raw output of the wrapped model (no reshaping)."""
        return self.model(x)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters using ``self.lr``."""
        # Created here rather than in __init__ so the optimizer is built when
        # Lightning asks for it, and it reflects any change to self.lr made
        # beforehand.
        return optim.Adam(self.parameters(), lr=self.lr)
|