"""Evaluate a trained image-classification checkpoint and export it in HuggingFace format."""
from datamodules import CIFAR10QADataModule, ImageDataModule
from datamodules.utils import datamodule_factory
from models import ImageClassificationNet
from models.utils import model_factory
from pytorch_lightning.loggers import WandbLogger
import argparse
import pytorch_lightning as pl
def main(args: argparse.Namespace):
    """Evaluate a checkpointed classifier and export the HF model.

    Restores the Lightning module from ``args.checkpoint``, runs one test
    pass over the chosen datamodule while logging to Weights & Biases, and
    finally saves the wrapped HuggingFace model and its feature extractor
    so they can be reloaded later via ``--from_pretrained``.
    """
    # Seed every RNG for a reproducible evaluation run.
    pl.seed_everything(args.seed)

    # Build the backbone architecture and the matching datamodule.
    backbone = model_factory(args, own_config=True)
    datamodule = datamodule_factory(args)

    # Restore the Lightning wrapper around the backbone; no training
    # happens here, so the scheduler step count can be zero.
    model = ImageClassificationNet.load_from_checkpoint(
        args.checkpoint,
        model=backbone,
        num_train_steps=0,
    )

    # Log evaluation metrics to Weights & Biases.
    logger = WandbLogger(
        name=f"{args.dataset}_eval_{args.base_model} ({args.from_pretrained})",
        project="Patch-DiffMask",
    )

    trainer = pl.Trainer(
        accelerator="auto",
        logger=logger,
        max_epochs=1,
        enable_progress_bar=args.enable_progress_bar,
    )

    # Single pass over the test set.
    trainer.test(model, datamodule)

    # Persist the inner HF model + feature extractor for --from_pretrained.
    export_dir = f"checkpoints/{args.base_model}_{args.dataset}"
    model.model.save_pretrained(export_dir)
    datamodule.feature_extractor.save_pretrained(export_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required: which checkpoint to evaluate. (Previous help text said
    # "resume the training from", but this script only evaluates/exports.)
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Checkpoint to load the model weights from for evaluation.",
    )

    # Trainer
    parser.add_argument(
        "--enable_progress_bar",
        action="store_true",
        help="Whether to show progress bar during training. NOT recommended when logging to files.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed for reproducibility.",
    )

    # Base (classification) model
    parser.add_argument(
        "--base_model",
        type=str,
        default="ViT",
        choices=["ViT", "ConvNeXt"],
        help="Base model architecture to train.",
    )
    parser.add_argument(
        "--from_pretrained",
        type=str,
        help="The name of the pretrained HF model to fine-tune from.",
    )

    # Datamodule: let each datamodule register its own CLI options.
    ImageDataModule.add_model_specific_args(parser)
    CIFAR10QADataModule.add_model_specific_args(parser)
    parser.add_argument(
        "--dataset",
        type=str,
        default="toy",
        choices=["MNIST", "CIFAR10", "CIFAR10_QA", "toy"],
        help="The dataset to use.",
    )

    args = parser.parse_args()
    main(args)