# realfake/realfake/train_cluster.py
# Author: devforfu
# Commit: ea847ad ("Init")
import warnings
import pytorch_lightning as pl
from realfake.config import SEED
from realfake.models import RealFakeClassifier, RealFakeDataModule, RealFakeParams
from realfake.train import prepare_trainer
def main() -> None:
    """Entry point: build the model, data module and trainer, then train or dry-run.

    RNGs are seeded with the project-wide ``SEED`` first. Parameters come from
    ``RealFakeParams.from_args()``. When ``dry_run`` is set, the model and the
    dataloader batch counts are printed instead of calling ``trainer.fit``.
    """
    pl.seed_everything(SEED)
    params = RealFakeParams.from_args()

    classifier = RealFakeClassifier(params)
    # NOTE(review): loader workers are set to 4x the accelerator device count;
    # assumes `accelerator.devices` is an int -- confirm for "auto"/list configs.
    num_workers = params.accelerator.devices * 4
    datamodule = RealFakeDataModule(params.jsonl_file, params.batch_size, num_workers)
    # The trainer is built even for a dry run, matching the original ordering.
    trainer = prepare_trainer(params)

    if not params.dry_run:
        trainer.fit(classifier, datamodule=datamodule)
        return

    print("Dry run, skipping training.")
    print("Model summary:", classifier, sep="\n")
    print("Data summary:")
    datamodule.setup()
    # dls[0] is reported as the training loader, dls[1] as the validation loader.
    for label, index in (("Train batches:", 0), ("Valid batches:", 1)):
        print(label, len(datamodule.dls[index]))
if __name__ == "__main__":
    # Run with every warning category silenced (Warning is the base class of all
    # warning types). catch_warnings() restores the previous filters on exit.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=Warning)
        main()