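"""Training script for the RelikReader relation-extraction (RE) model.

Wires together Hydra-instantiated datasets, a PyTorch Lightning module, a
layer-wise LR-decay optimizer, evaluation and checkpointing callbacks, and a
Weights & Biases logger, then runs ``Trainer.fit``.
"""
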
import hydra
import lightning
from hydra.utils import to_absolute_path
from lightning import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader

from relik.reader.lightning_modules.relik_reader_re_pl_module import (
    RelikReaderREPLModule,
)
from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer
from relik.reader.relik_reader_re_data import RelikREDataset
from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback
from relik.reader.utils.special_symbols import get_special_symbols_re


def train(cfg: DictConfig) -> None:
    lightning.seed_everything(cfg.training.seed)

    special_symbols = get_special_symbols_re(cfg.model.entities_per_forward)

    # datasets declaration: hydra.utils.instantiate builds each dataset from
    # its _target_ config entry; kwargs passed here override the config values
    train_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=to_absolute_path(cfg.data.train_dataset_path),
        special_symbols=special_symbols,
    )

    # add special_symbols to the validation dataset config as well, since the
    # evaluation callback instantiates its own dataset from that config and
    # needs them too
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols = special_symbols

    val_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )

    # model declaration
    model = RelikReaderREPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        training=True,
    )
    # reuse the dataset's tokenizer so model and data agree on special symbols
    model.relik_reader_re_model._tokenizer = train_dataset.tokenizer

    # optimizer declaration
    opt_conf = cfg.model.optimizer

    # alternative plain-AdamW factory, kept for reference:
    # adamw_optimizer_factory = AdamWWithWarmupOptimizer(
    #     lr=opt_conf.lr,
    #     warmup_steps=opt_conf.warmup_steps,
    #     total_steps=opt_conf.total_steps,
    #     no_decay_params=opt_conf.no_decay_params,
    #     weight_decay=opt_conf.weight_decay,
    # )

    electra_optimizer_factory = LayerWiseLRDecayOptimizer(
        lr=opt_conf.lr,
        warmup_steps=opt_conf.warmup_steps,
        total_steps=opt_conf.total_steps,
        total_reset=opt_conf.total_reset,
        no_decay_params=opt_conf.no_decay_params,
        weight_decay=opt_conf.weight_decay,
        lr_decay=opt_conf.lr_decay,
    )
    model.set_optimizer_factory(electra_optimizer_factory)
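
    # Note on LayerWiseLRDecayOptimizer: layer-wise LR decay (as used in
    # ELECTRA-style fine-tuning) scales the learning rate down from the top
    # transformer layer towards the embeddings, so lower layers update more
    # conservatively. With top-layer rate lr and factor lr_decay, the i-th
    # layer from the top trains at roughly lr * lr_decay**i. This is a sketch
    # of the usual scheme; the exact parameter grouping is defined inside
    # LayerWiseLRDecayOptimizer.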

    # callbacks declaration
    callbacks = [
        # strong-matching relation evaluation on the validation set
        REStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
        # keep the checkpoint with the best validation F1 under "model/"
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_f1:.2f}",
            monitor="val_f1",
            mode="max",
        ),
        LearningRateMonitor(),
    ]

    wandb_logger = WandbLogger(cfg.model_name, project=cfg.project_name)

    # trainer declaration
    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    # Trainer fit: batch_size=None disables automatic batching, since the
    # datasets yield already-batched items
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
    )


# Hydra entry point: train() requires a DictConfig, so main() must receive one
# from @hydra.main. The config_path/config_name below are placeholders and must
# match the project's actual Hydra config layout.
@hydra.main(config_path="../conf", config_name="config")
def main(cfg: DictConfig) -> None:
    train(cfg)


if __name__ == "__main__":
    main()
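
# Example launch with Hydra command-line overrides (illustrative values only;
# the script name and defaults depend on how this file is saved/configured):
#   python train.py training.seed=42 model.optimizer.lr=1e-5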