File size: 2,701 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from models import ImageClassificationNet
from models.utils import model_factory
from pytorch_lightning.loggers import WandbLogger

import argparse
import pytorch_lightning as pl


def main(args: argparse.Namespace):
    """Evaluate a checkpointed classifier and export it for later reuse.

    The function seeds all RNGs, builds the base model and the datamodule
    from ``args``, restores an ``ImageClassificationNet`` from
    ``args.checkpoint``, runs the Lightning test loop while logging to
    Weights & Biases, and finally writes the wrapped model and the
    datamodule's feature extractor to ``checkpoints/<base_model>_<dataset>``.

    Args:
        args: Parsed command-line arguments. This function reads ``seed``,
            ``checkpoint``, ``dataset``, ``base_model``, ``from_pretrained``
            and ``enable_progress_bar`` directly; the model and datamodule
            factories receive the whole namespace.
    """
    # Seed first so that everything constructed below is reproducible
    pl.seed_everything(args.seed)

    # Build the base model, then the datamodule (same order as seeding expects)
    backbone = model_factory(args, own_config=True)
    datamodule = datamodule_factory(args)

    # Restore the Lightning module from the checkpoint around the base model.
    # num_train_steps=0: presumably irrelevant here since nothing is trained.
    classifier = ImageClassificationNet.load_from_checkpoint(
        args.checkpoint, model=backbone, num_train_steps=0
    )

    # Weights & Biases run, named after dataset / architecture / pretrained id
    run_name = f"{args.dataset}_eval_{args.base_model} ({args.from_pretrained})"
    run_logger = WandbLogger(name=run_name, project="Patch-DiffMask")

    # Trainer that is only used to drive the test loop
    trainer_options = {
        "accelerator": "auto",
        "logger": run_logger,
        "max_epochs": 1,
        "enable_progress_bar": args.enable_progress_bar,
    }
    evaluator = pl.Trainer(**trainer_options)
    evaluator.test(classifier, datamodule)

    # Export the HuggingFace model (and its feature extractor) so the
    # directory can be passed to --from_pretrained
    export_dir = f"checkpoints/{args.base_model}_{args.dataset}"
    classifier.model.save_pretrained(export_dir)
    datamodule.feature_extractor.save_pretrained(export_dir)


if __name__ == "__main__":
    # Command-line entry point of the evaluation script: declare the options,
    # parse them and hand the resulting namespace to main().
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint to load the model to evaluate from.",
    )

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during evaluation. NOT recommended when logging to files.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Base (classification) model
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to evaluate.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        # Deliberately no default (parses to None when omitted); an example
        # value is "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10".
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # Datamodule: each datamodule class registers its own options on the
    # shared parser before the dataset selector is added.
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()

    main(args)