sachin committed on
Commit bd0d978
1 Parent(s): 6d1b6c6

Corrected configure_optimizers in lightning module

Files changed (3):
  1. src/config.py +1 -0
  2. src/lightning_module.py +107 -0
  3. src/trainer.py +0 -91
src/config.py CHANGED
@@ -101,6 +101,7 @@ class TrainerConfig(pydantic.BaseModel):
  epochs: int = 20
  batch_size: int = 64
  learning_rate: float = 5e-4
+ lr_scheduler: bool = True
  accumulate_grad_batches: int = 1
  temperature: float = 1.0
  vision_freeze_layers: int = 2
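
The new lr_scheduler flag is what configure_optimizers in the new lightning module (added below) checks before attaching a OneCycleLR schedule. A minimal sketch of toggling it, assuming TrainerConfig's remaining fields all have defaults like the ones shown here:

    from src.config import TrainerConfig

    cfg = TrainerConfig()                               # scheduler on by default
    cfg_no_sched = TrainerConfig(lr_scheduler=False)    # plain Adam, constant LR
    print(cfg.lr_scheduler, cfg_no_sched.lr_scheduler)  # True False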
src/lightning_module.py ADDED
@@ -0,0 +1,107 @@
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+
+ from src import config
+ from src import loss as loss_utils
+ from src import metrics
+ from src import models
+
+
+ class LightningModule(pl.LightningModule):
+     def __init__(
+         self,
+         vision_encoder: models.TinyCLIPVisionEncoder,
+         text_encoder: models.TinyCLIPTextEncoder,
+         loss_fn: nn.Module,
+         hyper_parameters: config.TrainerConfig,
+         len_train_dl: int,
+     ) -> None:
+         super().__init__()
+         self.vision_encoder = vision_encoder
+         self.text_encoder = text_encoder
+         self.loss_fn = loss_fn
+         self.hyper_parameters = hyper_parameters
+         self.len_train_dl = len_train_dl
+
+     def common_step(self, batch: tuple[torch.Tensor, list[str]], step_kind: str) -> torch.Tensor:
+         text, images = batch
+         image_features = self.vision_encoder(images)
+         text_features = self.text_encoder(text)
+         similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)
+
+         loss = self.loss_fn(similarity_matrix, image_features, text_features)
+
+         img_acc, cap_acc = metrics.metrics(similarity_matrix)
+
+         self.log(f"{step_kind}_loss", loss, on_step=False, on_epoch=True)
+         self.log(f"{step_kind}_img_acc", img_acc, on_step=False, on_epoch=True, prog_bar=True)
+         self.log(f"{step_kind}_cap_acc", cap_acc, on_step=False, on_epoch=True, prog_bar=True)
+         return loss
+
+     def training_step(self, batch: tuple[torch.Tensor, list[str]], *args: list) -> torch.Tensor:
+         loss = self.common_step(batch, step_kind="training")
+         return loss
+
+     def validation_step(self, batch: tuple[torch.Tensor, list[str]], *args: list):
+         # Log under a "validation" prefix so these metrics are not mixed in with training ones.
+         _ = self.common_step(batch, step_kind="validation")
+
+     def configure_optimizers(self):
+         vision_params = [
+             {
+                 "params": self.vision_encoder.projection.parameters(),
+                 "lr": self.hyper_parameters.learning_rate,
+             },
+         ]
+         caption_params = [
+             {
+                 "params": self.text_encoder.projection.parameters(),
+                 "lr": self.hyper_parameters.learning_rate,
+             },
+         ]
+         loss_params = [
+             {
+                 "params": self.loss_fn.parameters(),
+                 "lr": self.hyper_parameters.learning_rate,
+             },
+         ]
+
+         if not self.hyper_parameters._model_config.freeze_text_base:
+             caption_params += [
+                 {
+                     "params": self.text_encoder.base.parameters(),
+                     "lr": self.hyper_parameters.learning_rate / 2,
+                 },
+             ]
+
+         # The vision base is registered only here; listing it unconditionally above as
+         # well would put the same tensors in two param groups, which Adam rejects.
+         if not self.hyper_parameters._model_config.freeze_vision_base:
+             vision_params += [
+                 {
+                     "params": self.vision_encoder.base.parameters(),
+                     "lr": self.hyper_parameters.learning_rate / 2,
+                 },
+             ]
+
+         optimizer = torch.optim.Adam(
+             vision_params + caption_params + loss_params, lr=self.hyper_parameters.learning_rate
+         )
+
+         if self.hyper_parameters.lr_scheduler:
+             scheduler = torch.optim.lr_scheduler.OneCycleLR(
+                 optimizer,
+                 max_lr=self.hyper_parameters.learning_rate,
+                 total_steps=int(self.trainer.estimated_stepping_batches),
+             )
+             return [optimizer], [scheduler]
+         else:
+             return optimizer
+
+     def on_epoch_end(self):
+         # After the first epoch of projection-only training, unfreeze the vision base.
+         if self.current_epoch == 0:
+             for p in self.vision_encoder.base.parameters():
+                 p.requires_grad = True
+             self.vision_encoder.base.train()
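
One caveat worth flagging in the new configure_optimizers: the `return [optimizer], [scheduler]` form leaves Lightning's scheduler interval at its default of once per epoch, while a OneCycleLR sized with total_steps=estimated_stepping_batches expects to be stepped once per optimizer step. A sketch of the dict return form Lightning also accepts, which pins the interval explicitly (a possible alternative, not what this commit does):

    # inside configure_optimizers, replacing `return [optimizer], [scheduler]`
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",  # step OneCycleLR every batch, matching total_steps
        },
    }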
src/trainer.py CHANGED
@@ -1,91 +0,0 @@
- import pytorch_lightning as pl
- import torch
- import torch.nn as nn
-
- from src import config
- from src import loss as loss_utils
- from src import metrics
- from src import models
-
-
- class LightningModule(pl.LightningModule):
-     def __init__(
-         self,
-         vision_encoder: models.TinyCLIPVisionEncoder,
-         text_encoder: models.TinyCLIPTextEncoder,
-         loss_fn: nn.Module,
-         hyper_parameters: config.TrainerConfig,
-         len_train_dl: int,
-     ) -> None:
-         super().__init__()
-         self.vision_encoder = vision_encoder
-         self.text_encoder = text_encoder
-         self.loss_fn = loss_fn
-         self.hyper_parameters = hyper_parameters
-         self.len_train_dl = len_train_dl
-
-     def common_step(self, batch: tuple[torch.Tensor, list[str]], step_kind: str) -> torch.Tensor:
-         text, images = batch
-         image_features = self.vision_encoder(images)
-         text_features = self.text_encoder(text)
-         similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)
-
-         loss = self.loss_fn(similarity_matrix, image_features, text_features)
-
-         img_acc, cap_acc = metrics.metrics(similarity_matrix)
-
-         self.log(f"{step_kind}_loss", loss, on_step=False, on_epoch=True)
-         self.log(f"{step_kind}_img_acc", img_acc, on_step=False, on_epoch=True, prog_bar=True)
-         self.log(f"{step_kind}_cap_acc", cap_acc, on_step=False, on_epoch=True, prog_bar=True)
-         return loss
-
-     def training_step(self, batch: tuple[torch.Tensor, list[str]], *args: list) -> torch.Tensor:
-         loss = self.common_step(batch, step_kind="training")
-         return loss
-
-     def validation_step(self, batch: tuple[torch.Tensor, list[str]], *args: list):
-         _ = self.common_step(batch, step_kind="training")
-
-     def configure_optimizers(self):
-         # TODO: Add loss parameters here
-         vision_params = [
-             {
-                 "params": self.vision_encoder.projection.parameters(),
-                 "lr": self.hyper_parameters.learning_rate,
-             },
-             {
-                 "params": self.vision_encoder.base.parameters(),
-                 "lr": self.hyper_parameters.learning_rate / 2,
-             },
-         ]
-         caption_params = [
-             {
-                 "params": self.text_encoder.projection.parameters(),
-                 "lr": self.hyper_parameters.learning_rate,
-             },
-         ]
-         if not self.hyper_parameters.freeze_text_base:
-             caption_params += [
-                 {
-                     "params": self.text_encoder.base.encoder.parameters(),
-                     "lr": self.hyper_parameters.learning_rate / 2,
-                 },
-             ]
-
-         optimizer = torch.optim.Adam(vision_params + caption_params)
-
-         if self.hyper_parameters.lr_scheduler:
-             scheduler = torch.optim.lr_scheduler.OneCycleLR(
-                 optimizer,
-                 max_lr=self.hyper_parameters.learning_rate,
-                 total_steps=self.trainer.estimated_stepping_batches,
-             )
-             return [optimizer], [scheduler]
-         else:
-             return optimizer
-
-     def on_epoch_end(self):
-         if self.current_epoch == 0:
-             for p in self.vision_encoder.base.parameters():
-                 p.requires_grad = True
-             self.vision_encoder.base.train()
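
With src/trainer.py removed, the module is presumably driven directly by pytorch_lightning.Trainer. A hypothetical wiring sketch: the LightningModule signature comes from this diff, but the encoder constructors, loss, and dataloaders below are assumed stand-ins for whatever src.models, src.loss, and the data code actually provide:

    import pytorch_lightning as pl

    from src import config, models
    from src.lightning_module import LightningModule

    cfg = config.TrainerConfig()

    # Assumed: no-arg encoder constructors and a hypothetical dataloader helper.
    train_dl, val_dl = build_dataloaders(cfg.batch_size)
    module = LightningModule(
        vision_encoder=models.TinyCLIPVisionEncoder(),
        text_encoder=models.TinyCLIPTextEncoder(),
        loss_fn=contrastive_loss,  # hypothetical nn.Module with learnable parameters
        hyper_parameters=cfg,
        len_train_dl=len(train_dl),
    )

    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        accumulate_grad_batches=cfg.accumulate_grad_batches,
    )
    trainer.fit(module, train_dl, val_dl)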