"""Entry point: train a RealFakeClassifier, or print a summary on a dry run."""

import warnings

import pytorch_lightning as pl

from realfake.config import SEED
from realfake.models import RealFakeClassifier, RealFakeDataModule, RealFakeParams
from realfake.train import prepare_trainer


def _print_dry_run_summary(classifier, datamodule) -> None:
    """Print the model and the train/valid batch counts instead of training.

    Calls ``datamodule.setup()`` so that ``datamodule.dls`` is populated, then
    reports ``len`` of its first two entries (train, then valid).
    """
    print("Dry run, skipping training.")
    print("Model summary:")
    print(classifier)
    print("Data summary:")
    datamodule.setup()
    # Each split is looked up right before it is printed, so a missing
    # validation entry only fails after the train count has been shown.
    print("Train batches:", len(datamodule.dls[0]))
    print("Valid batches:", len(datamodule.dls[1]))


def main() -> None:
    """Seed RNGs, build model/data/trainer from CLI params, then fit or dry-run."""
    # Seed before constructing the model so its initialisation is reproducible.
    pl.seed_everything(SEED)
    params = RealFakeParams.from_args()
    classifier = RealFakeClassifier(params)
    datamodule = RealFakeDataModule(
        params.jsonl_file,
        params.batch_size,
        # NOTE(review): presumably the DataLoader worker count (4 per device);
        # confirm against RealFakeDataModule's signature.
        params.accelerator.devices * 4,
    )
    # Built up front, before the dry-run check, exactly as in the original flow.
    trainer = prepare_trainer(params)

    if params.dry_run:
        _print_dry_run_summary(classifier, datamodule)
        return

    trainer.fit(classifier, datamodule=datamodule)


if __name__ == "__main__":
    with warnings.catch_warnings():
        # ``Warning`` is the base class of all warning categories, so this
        # silences every warning for the duration of the run.
        warnings.filterwarnings("ignore", category=Warning)
        main()