|
import warnings |
|
|
|
import pytorch_lightning as pl |
|
|
|
from realfake.config import SEED |
|
from realfake.models import RealFakeClassifier, RealFakeDataModule, RealFakeParams |
|
from realfake.train import prepare_trainer |
|
|
|
|
|
def main() -> None:
    """Train the real/fake classifier, or only summarise the setup on a dry run.

    Steps, in order:
      1. Seed every RNG via ``pl.seed_everything(SEED)``.
      2. Parse run parameters with ``RealFakeParams.from_args()``.
      3. Build the model, the data module and the trainer from those parameters.
         The trainer is built even on a dry run, so any side effects of
         ``prepare_trainer`` still happen in that mode.
      4. If ``dry_run`` is set, print the model and the train/valid batch
         counts. Otherwise run ``trainer.fit``.
    """
    pl.seed_everything(SEED)

    params = RealFakeParams.from_args()
    classifier = RealFakeClassifier(params)

    # Third positional argument to the data module: 4 per accelerator device.
    # NOTE(review): presumably the DataLoader worker count — confirm against
    # RealFakeDataModule's signature.
    workers = params.accelerator.devices * 4
    datamodule = RealFakeDataModule(params.jsonl_file, params.batch_size, workers)

    trainer = prepare_trainer(params)

    if not params.dry_run:
        trainer.fit(classifier, datamodule=datamodule)
        return

    # Dry run: report what would be trained, without fitting.
    print("Dry run, skipping training.")
    print("Model summary:")
    print(classifier)
    print("Data summary:")
    datamodule.setup()
    # Index 0 is reported as the training loader and index 1 as validation.
    for label, index in (("Train batches:", 0), ("Valid batches:", 1)):
        print(label, len(datamodule.dls[index]))
|
|
|
|
|
if __name__ == "__main__":
    # Run the script with every warning suppressed. ``simplefilter("ignore")``
    # defaults to ``category=Warning``, the root of the warning hierarchy, so
    # nothing is shown. ``catch_warnings`` restores the previous filter list
    # once ``main()`` returns or raises.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        main()
|
|