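"""Training script for the relik reader relation-extraction (cIE) model.

Instantiates the datasets, model, optimizer, callbacks and trainer from the
Hydra config and runs PyTorch Lightning training.
"""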
import hydra
import lightning
from hydra.utils import to_absolute_path
from lightning import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader

from relik.reader.data.relik_reader_re_data import RelikREDataset
from relik.reader.lightning_modules.relik_reader_re_pl_module import (
    RelikReaderREPLModule,
)
from relik.reader.pytorch_modules.optim import (
    AdamWWithWarmupOptimizer,
    LayerWiseLRDecayOptimizer,
)
from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback
from relik.reader.utils.special_symbols import (
    get_special_symbols,
    get_special_symbols_re,
)


@hydra.main(config_path="../conf", config_name="config_cie")
def train(cfg: DictConfig) -> None:
    lightning.seed_everything(cfg.training.seed)

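    # marker tokens for relations and entity types; one symbol is reserved for
    # each candidate slot handled in a single forward pass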
    special_symbols = get_special_symbols_re(cfg.model.relations_per_forward)
    special_symbols_types = get_special_symbols(cfg.model.entities_per_forward)
    # datasets declaration
    train_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=to_absolute_path(cfg.data.train_dataset_path),
        special_symbols=special_symbols,
        special_symbols_types=special_symbols_types,
    )

    # add the special symbols to the validation dataset config as well, since
    # they are also required by the EvaluationCallback's dataset config
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols = special_symbols
        cfg.data.val_dataset.special_symbols_types = special_symbols_types

    val_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )

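    # when requested, iterate the validation data once up front so that its
    # samples are materialized before training starts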
    if val_dataset.materialize_samples:
        list(val_dataset.dataset_iterator_func())

    # model declaration
    model = RelikReaderREPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        additional_special_symbols_types=len(special_symbols_types),
        entity_type_loss=True,
        add_entity_embedding=True,
        training=True,
    )
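    # the reader model reuses the dataset tokenizer, which carries the special symbols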
    model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
    # optimizer declaration
    opt_conf = cfg.model.optimizer

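    # AdamW with warmup by default; switch to layer-wise learning-rate decay
    # when the config provides a `total_reset` value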
    if "total_reset" not in opt_conf:
        optimizer_factory = AdamWWithWarmupOptimizer(
            lr=opt_conf.lr,
            warmup_steps=opt_conf.warmup_steps,
            total_steps=opt_conf.total_steps,
            no_decay_params=opt_conf.no_decay_params,
            weight_decay=opt_conf.weight_decay,
        )
    else:
        optimizer_factory = LayerWiseLRDecayOptimizer(
            lr=opt_conf.lr,
            warmup_steps=opt_conf.warmup_steps,
            total_steps=opt_conf.total_steps,
            total_reset=opt_conf.total_reset,
            no_decay_params=opt_conf.no_decay_params,
            weight_decay=opt_conf.weight_decay,
            lr_decay=opt_conf.lr_decay,
        )

    model.set_optimizer_factory(optimizer_factory)

    # callbacks: strong-matching relation evaluation on the validation set,
    # checkpointing on best val_f1, and learning-rate logging
    callbacks = [
        REStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_f1:.2f}",
            monitor="val_f1",
            mode="max",
        ),
        LearningRateMonitor(),
    ]

    wandb_logger = WandbLogger(
        cfg.model_name, project=cfg.project_name  # pass offline=True to skip syncing
    )

    # trainer declaration
    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    # fit: batch_size=None disables the DataLoader's automatic batching, since
    # the datasets already yield pre-batched samples
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
        ckpt_path=cfg.training.ckpt_path if cfg.training.ckpt_path else None,
    )

    # optionally reload the best checkpoint and export it in save_pretrained format
    if cfg.training.save_model_path:
        model = RelikReaderREPLModule.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )
        model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
        model.relik_reader_re_model.save_pretrained(cfg.training.save_model_path)


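# console-script entry point wrapping the hydra-decorated train()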
def main():
    train()


if __name__ == "__main__":
    main()