import shutil

import pandas as pd
import pytorch_lightning as pl
import transformers
import wandb

from config import CONFIG
from data import (
    get_annotation_ground_truth_str_from_image_index,
    load_train_image_ids,
    build_dataloader,
    Split,
    Batch,
)
from metrics import benetech_score_string_prediction
from model import generate_token_strings, LightningModule
from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus


class MetricsCallback(pl.callbacks.Callback):
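    """Log per-sample Benetech scores and a table of predicted vs. ground-truth
    strings to Weights & Biases during validation."""
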
    def on_validation_batch_start(
        self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
    ):
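        # Generate predicted data strings for the validation images.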
        predicted_strings = generate_token_strings(pl_module.model, images=batch.images)

        for expected_data_index, predicted_string in zip(
            batch.data_indices, predicted_strings, strict=True
        ):
            benetech_score = benetech_score_string_prediction(
                expected_data_index=expected_data_index,
                predicted_string=predicted_string,
            )
            wandb.log(dict(benetech_score=benetech_score))

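        # Log a side-by-side W&B table of ground-truth and predicted strings.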
        ground_truth_strings = [
            get_annotation_ground_truth_str_from_image_index(i)
            for i in batch.data_indices
        ]
        train_image_ids = load_train_image_ids()
        string_ids = [train_image_ids[i] for i in batch.data_indices]
        strings_dataframe = pd.DataFrame(
            dict(
                string_ids=string_ids,
                ground_truth=ground_truth_strings,
                predicted=predicted_strings,
            )
        )
        wandb.log(dict(strings=wandb.Table(dataframe=strings_dataframe)))


class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
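    """Checkpoint plugin that saves and loads checkpoints in the Hugging Face
    save_pretrained/from_pretrained directory format instead of Lightning's
    single-file .ckpt format. Every object passed in must implement both
    methods (e.g. models and processors).
    """
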
    def __init__(
        self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
    ):
        super().__init__()
        self.pretrained_models = pretrained_models

    def save_checkpoint(self, checkpoint, path, storage_options=None):
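        # Each object writes its own files (config, weights, processor files)
        # into the same checkpoint directory.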
        for pretrained_model in self.pretrained_models:
            pretrained_model.save_pretrained(path)

    def load_checkpoint(self, path, storage_options=None):
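        # Reload the pretrained objects in place from the checkpoint directory.
        # Note: no Lightning state dict is returned, so this plugin restores
        # only the pretrained weights, not optimizer or trainer state.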
        self.pretrained_models = [
            pm.from_pretrained(path) for pm in self.pretrained_models
        ]

    def remove_checkpoint(self, path):
        # save_pretrained writes a directory of files, so delete the whole tree.
        shutil.rmtree(path)


def train():
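    # Avoid tokenizer-parallelism warnings in dataloader workers and make
    # CUDA device ids follow PCI bus order.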
    set_tokenizers_parallelism(False)
    set_torch_device_order_pci_bus()

    pl_module = LightningModule(CONFIG)

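    # Keep the top-k checkpoints ranked by validation loss.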
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=CONFIG.training_directory,
        monitor="val_loss",
        save_top_k=CONFIG.save_top_k_checkpoints,
    )
    metrics_callback = MetricsCallback()

    logger = pl.loggers.WandbLogger(
        project=CONFIG.wandb_project_name, save_dir=CONFIG.training_directory
    )

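    # Checkpoint both the processor and the encoder-decoder model in
    # Hugging Face format.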
    plugin = TransformersPreTrainedModelsCheckpointIO(
        [pl_module.model.processor, pl_module.model.encoder_decoder]
    )

    trainer = pl.Trainer(
        accelerator=CONFIG.accelerator,
        devices=CONFIG.devices,
        plugins=[plugin],
        callbacks=[model_checkpoint, metrics_callback],
        logger=logger,
        limit_train_batches=CONFIG.limit_train_batches,
        limit_val_batches=CONFIG.limit_val_batches,
    )

    trainer.fit(
        model=pl_module,
        train_dataloaders=build_dataloader(
            Split.train, pl_module.model.batch_collate_function
        ),
        val_dataloaders=build_dataloader(
            Split.val, pl_module.model.batch_collate_function
        ),
    )


if __name__ == "__main__":
    train()