import hydra
import lightning
from hydra.utils import to_absolute_path
from lightning import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer
from torch.utils.data import DataLoader

from relik.reader.lightning_modules.relik_reader_re_pl_module import (
    RelikReaderREPLModule,
)
from relik.reader.relik_reader_re_data import RelikREDataset
from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback
from relik.reader.utils.special_symbols import get_special_symbols_re


@hydra.main(config_path="conf", config_name="config")
def train(cfg: DictConfig) -> None:
    lightning.seed_everything(cfg.training.seed)

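    # special marker symbols, sized by the number of entities handled per
    # forward pass; they are added to the reader model's vocabulary below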
    special_symbols = get_special_symbols_re(cfg.model.entities_per_forward)

    # datasets declaration
    train_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=to_absolute_path(cfg.data.train_dataset_path),
        special_symbols=special_symbols,
    )

    # update of validation dataset config with special_symbols since they
    #  are required even from the EvaluationCallback dataset_config
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols = special_symbols

    val_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )

    # model declaration
    model = RelikReaderREPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        training=True,
    )
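    # share the dataset's tokenizer (already extended with the special symbols)
    # with the underlying reader model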
    model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
    # optimizer declaration
    opt_conf = cfg.model.optimizer

    # adamw_optimizer_factory = AdamWWithWarmupOptimizer(
    #     lr=opt_conf.lr,
    #     warmup_steps=opt_conf.warmup_steps,
    #     total_steps=opt_conf.total_steps,
    #     no_decay_params=opt_conf.no_decay_params,
    #     weight_decay=opt_conf.weight_decay,
    # )

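    # ELECTRA-style layer-wise learning-rate decay: lower transformer layers
    # are trained with learning rates progressively scaled down by lr_decay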
    electra_optimizer_factory = LayerWiseLRDecayOptimizer(
        lr=opt_conf.lr,
        warmup_steps=opt_conf.warmup_steps,
        total_steps=opt_conf.total_steps,
        total_reset=opt_conf.total_reset,
        no_decay_params=opt_conf.no_decay_params,
        weight_decay=opt_conf.weight_decay,
        lr_decay=opt_conf.lr_decay,
    )

    model.set_optimizer_factory(electra_optimizer_factory)

    # callbacks declaration
    callbacks = [
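        # strong-matching relation-extraction evaluation on the validation set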
        REStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
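        # keep the checkpoint with the best validation F1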
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_f1:.2f}",
            monitor="val_f1",
            mode="max",
        ),
        LearningRateMonitor(),
    ]

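    # logger declaration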
    wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name)

    # trainer declaration
    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    # Trainer fit
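    # batch_size=None disables the DataLoader's automatic batching, since the
    # datasets are expected to yield already-batched samples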
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
    )


def main():
    train()


if __name__ == "__main__":
    main()