import os
import shutil

import pandas as pd
import pytorch_lightning as pl
import transformers
import wandb

from config import CONFIG
from data import (
get_annotation_ground_truth_str_from_image_index,
load_train_image_ids,
build_dataloader,
Split,
Batch,
)
from metrics import benetech_score_string_prediction
from model import generate_token_strings, LightningModule
from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus


class MetricsCallback(pl.callbacks.Callback):
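    """Generates predictions on each validation batch and logs the
    Benetech score plus a ground-truth/prediction table to wandb."""
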
def on_validation_batch_start(
self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
):
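        # Generate the model's string predictions for every image in the batch.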
predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
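        # Score each prediction against its ground truth and log the scores individually.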
for expected_data_index, predicted_string in zip(
batch.data_indices, predicted_strings, strict=True
):
benetech_score = benetech_score_string_prediction(
expected_data_index=expected_data_index,
predicted_string=predicted_string,
)
wandb.log(dict(benetech_score=benetech_score))
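        # Log a table pairing each image id with its ground-truth and predicted strings.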
ground_truth_strings = [
get_annotation_ground_truth_str_from_image_index(i)
for i in batch.data_indices
]
        train_image_ids = load_train_image_ids()
        string_ids = [train_image_ids[i] for i in batch.data_indices]
strings_dataframe = pd.DataFrame(
dict(
string_ids=string_ids,
ground_truth=ground_truth_strings,
predicted=predicted_strings,
)
)
wandb.log(dict(strings=wandb.Table(dataframe=strings_dataframe)))


class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
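    """Checkpoint IO plugin that persists models through the transformers
    save_pretrained/from_pretrained interface instead of pickled state dicts."""
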
def __init__(
self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
):
super().__init__()
self.pretrained_models = pretrained_models

    def save_checkpoint(self, checkpoint, path, storage_options=None):
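        # Every wrapped object (processor and model) writes its own files under `path`.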
for pretrained_model in self.pretrained_models:
pretrained_model.save_pretrained(path)

    def load_checkpoint(self, path, map_location=None):
        self.pretrained_models = [
            pm.from_pretrained(path) for pm in self.pretrained_models
        ]
        # Only pretrained weights are stored, so there is no trainer state to return.
        return {}

    def remove_checkpoint(self, path):
        # save_pretrained creates a directory at `path`, which os.remove cannot delete.
        shutil.rmtree(path)


def train():
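    """Fits the model, checkpointing on val_loss and logging metrics to wandb."""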
set_tokenizers_parallelism(False)
set_torch_device_order_pci_bus()
pl_module = LightningModule(CONFIG)
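    # Keep only the best checkpoints, ranked by validation loss.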
model_checkpoint = pl.callbacks.ModelCheckpoint(
dirpath=CONFIG.training_directory,
monitor="val_loss",
save_top_k=CONFIG.save_top_k_checkpoints,
)
metrics_callback = MetricsCallback()
logger = pl.loggers.WandbLogger(
project=CONFIG.wandb_project_name, save_dir=CONFIG.training_directory
)
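    # Save the processor and encoder-decoder through their transformers interfaces.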
plugin = TransformersPreTrainedModelsCheckpointIO(
[pl_module.model.processor, pl_module.model.encoder_decoder]
)
trainer = pl.Trainer(
accelerator=CONFIG.accelerator,
devices=CONFIG.devices,
plugins=[plugin],
callbacks=[model_checkpoint, metrics_callback],
logger=logger,
limit_train_batches=CONFIG.limit_train_batches,
limit_val_batches=CONFIG.limit_val_batches,
)
trainer.fit(
model=pl_module,
train_dataloaders=build_dataloader(
Split.train, pl_module.model.batch_collate_function
),
val_dataloaders=build_dataloader(
Split.val, pl_module.model.batch_collate_function
),
)


if __name__ == "__main__":
train()