import pytorch_lightning as pl
import torch
import torch.nn as nn

from src import config
from src import loss as loss_utils
from src import metrics
from src import models


class LightningModule(pl.LightningModule):
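    """CLIP-style contrastive trainer pairing a vision and a text encoder.

    Encodes images and captions, scores them with a similarity matrix, applies
    the contrastive ``loss_fn``, and logs image/caption matching accuracies.
    ``configure_optimizers`` builds per-module parameter groups.
    """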
    def __init__(
        self,
        vision_encoder: models.TinyCLIPVisionEncoder,
        text_encoder: models.TinyCLIPTextEncoder,
        loss_fn: nn.Module,
        hyper_parameters: config.TrainerConfig,
        len_train_dl: int,
    ) -> None:
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder
        self.loss_fn = loss_fn
        self.hyper_parameters = hyper_parameters
        self.len_train_dl = len_train_dl

    def common_step(self, batch: dict[str, torch.Tensor], step_kind: str) -> torch.Tensor:
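        """Encode a batch, compute the contrastive loss, and log epoch-level metrics.

        ``step_kind`` ("training" or "validation") prefixes the logged metric names.
        """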
        image_features = self.vision_encoder(batch["images"])
        text_features = self.text_encoder(
            {key: value for key, value in batch.items() if key != "images"}
        )
        similarity_matrix = loss_utils.get_similarity_matrix(image_features, text_features)

        loss = self.loss_fn(similarity_matrix, image_features, text_features)

        # Image-to-text and text-to-image matching accuracies, derived from the
        # same similarity matrix used by the loss.
        img_acc, cap_acc = metrics.metrics(similarity_matrix)

        self.log(f"{step_kind}_loss", loss, on_step=False, on_epoch=True)
        self.log(f"{step_kind}_img_acc", img_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f"{step_kind}_cap_acc", cap_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch: dict[str, torch.Tensor], *args: list) -> torch.Tensor:
        return self.common_step(batch, step_kind="training")

    def validation_step(self, batch: dict[str, torch.Tensor], *args: list) -> None:
        self.common_step(batch, step_kind="validation")

    def configure_optimizers(self):
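        """Build an Adam optimizer with per-module parameter groups.

        Projection heads and the loss module use the full learning rate;
        encoder backbones, when not frozen, are added at half that rate.
        An optional ``OneCycleLR`` schedule is attached per optimizer step.
        """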
        vision_params = [
            {
                "params": self.vision_encoder.projection.parameters(),
                "lr": self.hyper_parameters.learning_rate,
            },
        ]
        caption_params = [
            {
                "params": self.text_encoder.projection.parameters(),
                "lr": self.hyper_parameters.learning_rate,
            },
        ]
        loss_params = [
            {
                "params": self.loss_fn.parameters(),
                "lr": self.hyper_parameters.learning_rate,
            },
        ]

        if not self.hyper_parameters._model_config.freeze_text_base:
            caption_params += [
                {
                    "params": self.text_encoder.base.parameters(),
                    "lr": self.hyper_parameters.learning_rate / 2,
                },
            ]

        if not self.hyper_parameters._model_config.freeze_vision_base:
            vision_params += [
                {
                    "params": self.vision_encoder.base.parameters(),
                    "lr": self.hyper_parameters.learning_rate / 2,
                },
            ]

        optimizer = torch.optim.Adam(
            vision_params + caption_params + loss_params, lr=self.hyper_parameters.learning_rate
        )

        if self.hyper_parameters.lr_scheduler:
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.hyper_parameters.learning_rate,
                total_steps=int(self.trainer.estimated_stepping_batches),
            )
            # OneCycleLR is sized in optimizer steps, so step it every batch
            # rather than Lightning's default once-per-epoch interval.
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        return optimizer

    def on_train_epoch_end(self) -> None:
        if self.current_epoch == 0:
            # After the first epoch, allow gradients through the vision
            # backbone again so it can be fine-tuned.
            for p in self.vision_encoder.base.parameters():
                p.requires_grad = True
            self.vision_encoder.base.train()
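

# Minimal usage sketch (illustrative, not part of the module): the no-argument
# constructors and the `ContrastiveLoss` name below are assumptions; the real
# signatures live in src.models, src.config, and src.loss.
#
#     vision_encoder = models.TinyCLIPVisionEncoder()    # assumed no-arg ctor
#     text_encoder = models.TinyCLIPTextEncoder()        # assumed no-arg ctor
#     hparams = config.TrainerConfig()                   # assumed defaults
#     module = LightningModule(
#         vision_encoder=vision_encoder,
#         text_encoder=text_encoder,
#         loss_fn=loss_utils.ContrastiveLoss(),          # hypothetical name; any nn.Module
#                                                        # matching common_step's call works
#         hyper_parameters=hparams,
#         len_train_dl=len(train_dl),                    # your training DataLoader
#     )
#     pl.Trainer(max_epochs=3).fit(module, train_dl, val_dl)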