vision-diffmask / code /eval_base.py
din0s's picture
Add code
d4ab5ac unverified
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from models import ImageClassificationNet
from models.utils import model_factory
from pytorch_lightning.loggers import WandbLogger
import argparse
import pytorch_lightning as pl
def main(args: argparse.Namespace):
    """Evaluate a checkpointed classifier on the test split and export it.

    The function seeds all RNGs, builds the base model and the datamodule
    through their factories, restores an ``ImageClassificationNet`` from
    ``args.checkpoint``, runs ``Trainer.test`` while logging to the
    "Patch-DiffMask" wandb project, and finally writes the wrapped model and
    the datamodule's feature extractor to
    ``checkpoints/{base_model}_{dataset}`` via ``save_pretrained``.

    Args:
        args: Parsed command-line options. This function reads ``seed``,
            ``checkpoint``, ``dataset``, ``base_model``, ``from_pretrained``
            and ``enable_progress_bar`` directly; the whole namespace is also
            forwarded to ``model_factory`` and ``datamodule_factory``, which
            may read further attributes.
    """
    # Seed first so that everything constructed below is reproducible.
    pl.seed_everything(args.seed)

    # Base network that the Lightning module wraps.
    backbone = model_factory(args, own_config=True)

    # Data source for the evaluation run.
    datamodule = datamodule_factory(args)

    # Restore the trained weights from the given checkpoint.
    # NOTE(review): num_train_steps=0 presumably because nothing is trained
    # here -- confirm against ImageClassificationNet.
    classifier = ImageClassificationNet.load_from_checkpoint(
        args.checkpoint,
        model=backbone,
        num_train_steps=0,
    )

    # Weights & Biases logger for the evaluation metrics.
    run_name = f"{args.dataset}_eval_{args.base_model} ({args.from_pretrained})"
    experiment_logger = WandbLogger(name=run_name, project="Patch-DiffMask")

    # Trainer that is only used to drive the test loop.
    trainer_options = {
        "accelerator": "auto",
        "logger": experiment_logger,
        "max_epochs": 1,
        "enable_progress_bar": args.enable_progress_bar,
    }
    evaluator = pl.Trainer(**trainer_options)

    # Run the evaluation.
    evaluator.test(classifier, datamodule)

    # Export in HuggingFace format so the result can be passed back in
    # through --from_pretrained.
    export_dir = f"checkpoints/{args.base_model}_{args.dataset}"
    for exportable in (classifier.model, datamodule.feature_extractor):
        exportable.save_pretrained(export_dir)
if __name__ == "__main__":
    # Command-line entry point: collect the options needed to evaluate a
    # trained classification checkpoint, then hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        # The checkpoint is only loaded and tested, never trained further.
        help="Checkpoint of the trained model to evaluate.",
    )

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during evaluation. NOT recommended when logging to files.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Base (classification) model
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to evaluate.",
    )
    # No default is set, so args.from_pretrained is None when the flag is
    # omitted (and appears as "(None)" in the wandb run name).
    parser.add_argument(
        "--from_pretrained",
        type=str,
        # default="tanlq/vit-base-patch16-224-in21k-finetuned-cifar10",
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # Datamodule
    # Each datamodule class registers its own options on the shared parser.
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()

    main(args)