---
license: apache-2.0
---

[STFPM](https://github.com/openvinotoolkit/anomalib/tree/main/anomalib/models/stfpm) model from [Anomalib](https://github.com/openvinotoolkit/anomalib), fine-tuned on the capsule category of the [MVTec dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad). The checkpoint was trained using the following [notebook](https://github.com/openvinotoolkit/anomalib/blob/main/notebooks/000_getting_started/001_getting_started.ipynb).

```
──────────────────────────────────────────────────
       Test metric             DataLoader 0
──────────────────────────────────────────────────
       image_AUROC          0.8436378240585327
      image_F1Score         0.9356223344802856
       pixel_AUROC          0.9719913601875305
      pixel_F1Score         0.41566985845565796
──────────────────────────────────────────────────
```

The main intent is to use this checkpoint in samples and demos for model optimization. Its advantages are:

- The MVTec dataset can be downloaded automatically and is quite small.
- Anomaly detection models such as STFPM are sensitive to optimization methods, which makes them a good fit for demonstrating methods with accuracy control.

Here is the code to test the checkpoint:

```python
import os
import urllib.request

from pytorch_lightning import Trainer
from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import LoadModelCallback, get_callbacks

CHECKPOINT_URL = 'https://huggingface.co/alexsu52/sftpm_mvtec_capsule/resolve/main/pytorch_model.bin'
# Expand "~" explicitly: neither urllib nor torch.load expands it.
CHECKPOINT_PATH = os.path.expanduser('~/pytorch_model.bin')

# Download CHECKPOINT_URL to CHECKPOINT_PATH (skipped if the file already exists).
if not os.path.exists(CHECKPOINT_PATH):
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)

config = get_configurable_parameters(config_path="./anomalib/models/stfpm/config.yaml")
# Test on the capsule category; config["dataset"]["path"] should point to your MVTec root directory.
config["dataset"]["category"] = "capsule"

datamodule = get_datamodule(config)
datamodule.prepare_data()  # Downloads the dataset if it's not in the specified `root` directory.
datamodule.setup()  # Creates train/val/test/prediction sets.
model = get_model(config)
callbacks = get_callbacks(config)

# Load the downloaded weights into the model before testing.
load_model_callback = LoadModelCallback(weights_path=CHECKPOINT_PATH)
callbacks.insert(0, load_model_callback)

trainer = Trainer(**config.trainer, callbacks=callbacks)
trainer.test(model=model, datamodule=datamodule)
```
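Since the card targets optimization demos with accuracy control, a natural next step after `trainer.test` is to compare an optimized model's metrics against the baseline numbers in the table above. The sketch below is a minimal, hypothetical helper: `accuracy_drop`, `within_budget`, and the 0.01 absolute drop budget on `image_AUROC` are illustrative assumptions, not part of Anomalib. It relies only on the fact that PyTorch Lightning's `Trainer.test` returns a list with one dict of logged metrics per test dataloader.

```python
from typing import Mapping

# Baseline test metrics of this checkpoint, copied from the table in this card.
BASELINE = {
    "image_AUROC": 0.8436378240585327,
    "image_F1Score": 0.9356223344802856,
    "pixel_AUROC": 0.9719913601875305,
    "pixel_F1Score": 0.41566985845565796,
}


def accuracy_drop(baseline: Mapping[str, float], optimized: Mapping[str, float], metric: str) -> float:
    """Absolute drop of `metric`; a positive value means the optimized model is worse."""
    return baseline[metric] - optimized[metric]


def within_budget(
    baseline: Mapping[str, float],
    optimized: Mapping[str, float],
    metric: str = "image_AUROC",
    max_drop: float = 0.01,  # hypothetical budget: at most 0.01 absolute drop
) -> bool:
    """Return True if the optimized model stays within `max_drop` of the baseline on `metric`."""
    return accuracy_drop(baseline, optimized, metric) <= max_drop


if __name__ == "__main__":
    # Synthetic example only: pretend optimization cost 0.02 on every metric.
    degraded = {name: value - 0.02 for name, value in BASELINE.items()}
    print(within_budget(BASELINE, BASELINE))  # no drop at all, so within budget
    print(within_budget(BASELINE, degraded))  # 0.02 exceeds the 0.01 budget
```

For example, `within_budget(BASELINE, trainer.test(model=optimized_model, datamodule=datamodule)[0])` would check a hypothetical `optimized_model`. The budget and the metric are per-demo design choices; pixel-level metrics such as `pixel_F1Score` may warrant their own thresholds.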