File size: 4,582 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import hydra
import lightning
from hydra.utils import to_absolute_path
from lightning import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader
from relik.reader.data.relik_reader_re_data import RelikREDataset
from relik.reader.lightning_modules.relik_reader_re_pl_module import (
RelikReaderREPLModule,
)
from relik.reader.pytorch_modules.optim import (
AdamWWithWarmupOptimizer,
LayerWiseLRDecayOptimizer,
)
from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback
from relik.reader.utils.special_symbols import (
get_special_symbols,
get_special_symbols_re,
)
@hydra.main(config_path="../conf", config_name="config_cie")
def train(cfg: DictConfig) -> None:
    """Train a RelikReader closed-information-extraction (CIE) model.

    Builds the train/validation datasets, the Lightning module, the
    optimizer factory, the callbacks and the trainer from the Hydra
    config, runs the fit loop and, if configured, re-exports the best
    checkpoint in "pretrained" format.

    Args:
        cfg: Hydra-resolved configuration (see ``../conf/config_cie``).
    """
    lightning.seed_everything(cfg.training.seed)

    special_symbols = get_special_symbols_re(cfg.model.relations_per_forward)
    special_symbols_types = get_special_symbols(cfg.model.entities_per_forward)

    # NOTE: _build_datasets mutates cfg.data.val_dataset in place; it must
    # run before _build_callbacks, which reads that same config node.
    train_dataset, val_dataset = _build_datasets(
        cfg, special_symbols, special_symbols_types
    )

    model = _build_model(cfg, special_symbols, special_symbols_types)
    model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
    model.set_optimizer_factory(_build_optimizer_factory(cfg.model.optimizer))

    wandb_logger = WandbLogger(
        cfg.model_name, project=cfg.project_name
    )  # , offline=True)

    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=_build_callbacks(cfg),
        logger=wandb_logger,
    )

    # Datasets batch themselves internally, hence batch_size=None.
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
        ckpt_path=cfg.training.ckpt_path if cfg.training.ckpt_path else None,
    )

    # Reload the best checkpoint (by val_f1) and save it as a pretrained model.
    if cfg.training.save_model_path:
        model = RelikReaderREPLModule.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )
        model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
        model.relik_reader_re_model.save_pretrained(cfg.training.save_model_path)


def _build_datasets(cfg, special_symbols, special_symbols_types):
    """Instantiate train/val datasets; propagate special symbols into the val config.

    The special symbols are written into ``cfg.data.val_dataset`` itself
    (not only passed as kwargs) because the evaluation callback
    re-instantiates the dataset from that config node.
    """
    train_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=to_absolute_path(cfg.data.train_dataset_path),
        special_symbols=special_symbols,
        special_symbols_types=special_symbols_types,
    )
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols = special_symbols
        cfg.data.val_dataset.special_symbols_types = special_symbols_types
    val_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )
    if val_dataset.materialize_samples:
        # Result discarded: presumably iterating caches the samples as a
        # side effect — TODO confirm against RelikREDataset.
        list(val_dataset.dataset_iterator_func())
    return train_dataset, val_dataset


def _build_model(cfg, special_symbols, special_symbols_types):
    """Create the PL module with entity-type loss and entity embeddings enabled."""
    return RelikReaderREPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        additional_special_symbols_types=len(special_symbols_types),
        entity_type_loss=True,
        add_entity_embedding=True,
        training=True,
    )


def _build_optimizer_factory(opt_conf):
    """Select the optimizer factory from the optimizer config.

    Uses layer-wise LR decay when ``total_reset`` is present in the
    config, plain AdamW-with-warmup otherwise.
    """
    if "total_reset" not in opt_conf:
        return AdamWWithWarmupOptimizer(
            lr=opt_conf.lr,
            warmup_steps=opt_conf.warmup_steps,
            total_steps=opt_conf.total_steps,
            no_decay_params=opt_conf.no_decay_params,
            weight_decay=opt_conf.weight_decay,
        )
    return LayerWiseLRDecayOptimizer(
        lr=opt_conf.lr,
        warmup_steps=opt_conf.warmup_steps,
        total_steps=opt_conf.total_steps,
        total_reset=opt_conf.total_reset,
        no_decay_params=opt_conf.no_decay_params,
        weight_decay=opt_conf.weight_decay,
        lr_decay=opt_conf.lr_decay,
    )


def _build_callbacks(cfg):
    """Strong-matching evaluation, best-val_f1 checkpointing, LR monitoring."""
    return [
        REStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_f1:.2f}",
            monitor="val_f1",
            mode="max",
        ),
        LearningRateMonitor(),
    ]
def main():
    """Entry point: delegate to the Hydra-decorated training routine."""
    train()
# Run training when this file is executed as a script.
if __name__ == "__main__":
    main()
|