Spaces:
Running
Running
from pytorch_lightning.callbacks import Callback
class IncreaseDataEpoch(Callback):
    """Push the current training epoch into the training dataset.

    At the start of every training epoch, if
    ``trainer.datamodule.train_dataset.shared_epoch`` is reachable and not
    ``None``, its ``set_value`` method is called with the module's current
    epoch number. If any link in that attribute chain is missing, the hook
    does nothing.
    """

    def on_train_epoch_start(self, trainer, pl_module):
        """Forward ``pl_module.current_epoch`` to the dataset's ``shared_epoch``.

        Args:
            trainer: The Lightning trainer running the fit loop; only its
                ``datamodule`` attribute is read.
            pl_module: The module being trained; only its ``current_epoch``
                attribute is read.
        """
        # Resolve each hop with a default: ``trainer.datamodule`` may be None
        # (e.g. when fitting without a datamodule) and a datamodule need not
        # define ``train_dataset``. A plain ``hasattr`` on the full chain would
        # raise AttributeError on those intermediate hops instead of skipping.
        datamodule = getattr(trainer, "datamodule", None)
        dataset = getattr(datamodule, "train_dataset", None)
        shared_epoch = getattr(dataset, "shared_epoch", None)
        if shared_epoch is None:
            return
        # NOTE(review): presumably a counter shared with dataloader worker
        # processes so they can see the epoch (e.g. for reseeding) — confirm.
        shared_epoch.set_value(pl_module.current_epoch)