Spaces:
Running
Running
File size: 379 Bytes
c4c7cee |
1 2 3 4 5 6 7 8 9 10 11 12 |
from pytorch_lightning.callbacks import Callback
class IncreaseDataEpoch(Callback):
    """Forward the current training epoch to the training dataset.

    At the start of each training epoch, if the datamodule's ``train_dataset``
    has a ``shared_epoch`` attribute, its ``set_value`` method is called with
    the LightningModule's ``current_epoch``. If there is no datamodule, no
    ``train_dataset``, or no ``shared_epoch``, the callback does nothing.
    """

    def __init__(self):
        super().__init__()

    def on_train_epoch_start(self, trainer, pl_module):
        """Push ``pl_module.current_epoch`` into ``train_dataset.shared_epoch``.

        Args:
            trainer: The Lightning trainer; only its ``datamodule`` attribute
                is read.
            pl_module: The module being trained; its ``current_epoch`` is the
                value forwarded to the dataset.

        Returns:
            None. The method returns early (no-op) when any link in
            ``trainer.datamodule.train_dataset.shared_epoch`` is missing or
            ``None``.
        """
        # trainer.datamodule may be None (e.g. dataloaders passed directly to
        # fit()); getattr(None, ..., None) yields None, so the chain below
        # degrades to a no-op instead of raising AttributeError.
        datamodule = getattr(trainer, "datamodule", None)
        dataset = getattr(datamodule, "train_dataset", None)
        shared_epoch = getattr(dataset, "shared_epoch", None)
        if shared_epoch is None:
            return
        # NOTE(review): shared_epoch is presumably a cross-process counter read
        # by dataloader workers (it exposes set_value) — confirm against the
        # dataset implementation.
        shared_epoch.set_value(pl_module.current_epoch)
|