File size: 939 Bytes
ea847ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import warnings

import pytorch_lightning as pl

from realfake.config import SEED
from realfake.models import RealFakeClassifier, RealFakeDataModule, RealFakeParams
from realfake.train import prepare_trainer


def main() -> None:
    """Seed, build the model, data module and trainer, then train or dry-run.

    Parameters come from ``RealFakeParams.from_args()``. When their
    ``dry_run`` flag is truthy nothing is trained: the model and the
    train/valid batch counts are printed instead. Otherwise the trainer is
    fitted on the model with the data module.
    """
    pl.seed_everything(SEED)

    params = RealFakeParams.from_args()
    classifier = RealFakeClassifier(params)
    # The third positional argument is four times the accelerator device
    # count (presumably the loader worker count -- confirm against
    # RealFakeDataModule).
    datamodule = RealFakeDataModule(
        params.jsonl_file,
        params.batch_size,
        params.accelerator.devices * 4,
    )
    # Built before the dry-run check, so trainer construction happens in
    # both modes.
    trainer = prepare_trainer(params)

    if not params.dry_run:
        trainer.fit(classifier, datamodule=datamodule)
        return

    print("Dry run, skipping training.")
    print("Model summary:")
    print(classifier)
    print("Data summary:")
    datamodule.setup()
    # dls[0] is reported as the train loader and dls[1] as the valid loader.
    for label, position in (("Train batches:", 0), ("Valid batches:", 1)):
        print(label, len(datamodule.dls[position]))


if __name__ == "__main__":
    # Run the entry point with every warning category silenced. The filter
    # is installed inside catch_warnings(), so the previous warning filters
    # are restored once main() returns or raises.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=Warning)
        main()