rzimmerdev committed on
Commit
2262103
·
1 Parent(s): 1de9461

feature: fixed training loop arguments for Lightning module

Browse files
Files changed (2) hide show
  1. notebooks/trainer.ipynb +56 -1
  2. src/trainer.py +11 -10
notebooks/trainer.ipynb CHANGED
@@ -4,7 +4,62 @@
4
  "cell_type": "code",
5
  "execution_count": null,
6
  "outputs": [],
7
- "source": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  "metadata": {
9
  "collapsed": false,
10
  "pycharm": {
 
4
  "cell_type": "code",
5
  "execution_count": null,
6
  "outputs": [],
7
+ "source": [
8
+ "import torch.optim\n",
9
+ "import pytorch_lightning as pl"
10
+ ],
11
+ "metadata": {
12
+ "collapsed": false,
13
+ "pycharm": {
14
+ "name": "#%%\n"
15
+ }
16
+ }
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "outputs": [],
22
+ "source": [
23
+ "class LitTrainer(pl.LightningModule):\n",
24
+ " def __init__(self, model, loss_fn, optim):\n",
25
+ " super().__init__()\n",
26
+ " self.model = model\n",
27
+ " self.loss_fn = loss_fn\n",
28
+ " self.optim = optim\n",
29
+ "\n",
30
+ " def training_step(self, batch, batch_idx):\n",
31
+ " x, y = batch\n",
32
+ " x = x.to(torch.float32)\n",
33
+ "\n",
34
+ " y_pred = self.model(x).reshape(1, -1)\n",
35
+ " train_loss = self.loss_fn(y_pred, y)\n",
36
+ "\n",
37
+ " self.log(\"train_loss\", train_loss)\n",
38
+ " return train_loss\n",
39
+ "\n",
40
+ " def validation_step(self, batch, batch_idx):\n",
41
+ " # this is the validation loop\n",
42
+ " x, y = batch\n",
43
+ " x = x.to(torch.float32)\n",
44
+ "\n",
45
+ " y_pred = self.model(x).reshape(1, -1)\n",
46
+ " validate_loss = self.loss_fn(y_pred, y)\n",
47
+ "\n",
48
+ " self.log(\"val_loss\", validate_loss)\n",
49
+ "\n",
50
+ " def test_step(self, batch, batch_idx):\n",
51
+ " # this is the test loop\n",
52
+ " x, y = batch\n",
53
+ " x = x.to(torch.float32)\n",
54
+ "\n",
55
+ " y_pred = self.model(x).reshape(1, -1)\n",
56
+ " test_loss = self.loss_fn(y_pred, y)\n",
57
+ "\n",
58
+ " self.log(\"test_loss\", test_loss)\n",
59
+ "\n",
60
+ " def configure_optimizers(self):\n",
61
+ " return self.optim\n"
62
+ ],
63
  "metadata": {
64
  "collapsed": false,
65
  "pycharm": {
src/trainer.py CHANGED
@@ -1,22 +1,22 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
- import torch.optim
 
4
  import pytorch_lightning as pl
5
 
6
 
7
  class LitTrainer(pl.LightningModule):
8
- def __init__(self, model, loss_fn, optim):
9
  super().__init__()
10
  self.model = model
11
- self.loss_fn = loss_fn
12
- self.optim = optim
13
 
14
  def training_step(self, batch, batch_idx):
15
  x, y = batch
16
- x = x.to(torch.float32)
17
 
18
  y_pred = self.model(x).reshape(1, -1)
19
- train_loss = self.loss_fn(y_pred, y)
20
 
21
  self.log("train_loss", train_loss)
22
  return train_loss
@@ -24,22 +24,23 @@ class LitTrainer(pl.LightningModule):
24
  def validation_step(self, batch, batch_idx):
25
  # this is the validation loop
26
  x, y = batch
27
- x = x.to(torch.float32)
28
 
29
  y_pred = self.model(x).reshape(1, -1)
30
- validate_loss = self.loss_fn(y_pred, y)
31
 
32
  self.log("val_loss", validate_loss)
33
 
34
  def test_step(self, batch, batch_idx):
35
  # this is the test loop
36
  x, y = batch
37
- x = x.to(torch.float32)
38
 
39
  y_pred = self.model(x).reshape(1, -1)
40
- test_loss = self.loss_fn(y_pred, y)
41
 
42
  self.log("test_loss", test_loss)
43
 
 
 
 
44
  def configure_optimizers(self):
45
  return self.optim
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
+ import torch
4
+ from torch import nn, optim
5
  import pytorch_lightning as pl
6
 
7
 
8
class LitTrainer(pl.LightningModule):
    """Lightning wrapper that trains an arbitrary model with cross-entropy loss.

    The wrapped ``model`` is optimized with Adam (lr=1e-4). Every step flattens
    the model output to shape ``(1, -1)`` before the loss is computed, and the
    per-stage loss is logged under ``train_loss`` / ``val_loss`` / ``test_loss``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # NOTE(review): the optimizer is built over self.parameters() at this
        # point, which covers only modules assigned so far (just ``model``) —
        # confirm this is intended if parameterized submodules are added later.
        self.optim = optim.Adam(self.parameters(), lr=1e-4)
        self.loss = nn.CrossEntropyLoss()

    def _batch_loss(self, batch):
        # Shared forward+loss used by the train/val/test steps.
        inputs, targets = batch
        flattened = self.model(inputs).reshape(1, -1)
        return self.loss(flattened, targets)

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the training loss for one batch."""
        loss_value = self._batch_loss(batch)
        self.log("train_loss", loss_value)
        return loss_value

    def validation_step(self, batch, batch_idx):
        """Validation loop: compute and log the loss (nothing is returned)."""
        self.log("val_loss", self._batch_loss(batch))

    def test_step(self, batch, batch_idx):
        """Test loop: compute and log the loss (nothing is returned)."""
        self.log("test_loss", self._batch_loss(batch))

    def forward(self, x):
        """Inference entry point: delegate straight to the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Return the Adam optimizer created in ``__init__``."""
        return self.optim