""" Gradio demo of image classification with OOD detection. If the image example is probably OOD, the model will abstain from the prediction. """ import os import pickle import json from glob import glob import gradio as gr from gradio.components import Image, Label, JSON import numpy as np import torch import timm from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names import logging _logger = logging.getLogger(__name__) device = "cuda" if torch.cuda.is_available() else "cpu" TOPK = 3 # load model print("Loading model...") model = timm.create_model("resnet50", pretrained=True) model.to(device) model.eval() # dataset labels idx2label = json.loads(open("ilsvrc2012.json").read()) idx2label = {int(k): v for k, v in idx2label.items()} print(idx2label) # transformation config = resolve_data_config({}, model=model) config["is_training"] = False transform = create_transform(**config) # print features names print(get_graph_node_names(model)[0]) # load train scores penultimate_features_key = "global_pool.flatten" logits_key = "fc" features_names = [penultimate_features_key, logits_key] # create feature extractor feature_extractor = create_feature_extractor(model, features_names) # OOD dtector thresholds msp_threshold = 0.3796 energy_threshold = 0.3781 ## unpickle detectors def mahalanobis_penult(features): scores = torch.norm(features, dim=1, keepdims=True) s = torch.min(scores, dim=1)[0] return -s.item() def msp(logits): return torch.softmax(logits, dim=1).max(-1)[0].item() def energy(logits): return torch.logsumexp(logits, dim=1).item() def predict(image): # forward pass inputs = transform(image).unsqueeze(0) with torch.no_grad(): features = feature_extractor(inputs) # top 5 predictions probabilities = torch.softmax(features[logits_key], dim=-1) softmax, class_idxs = torch.topk(probabilities, TOPK) _logger.info(softmax) _logger.info(class_idxs) result = {idx2label[i.item()]: v.item() for i, v in zip(class_idxs.squeeze(), softmax.squeeze())} # OOD msp_score = msp(features[logits_key]) energy_score = energy(features[logits_key]) ood_scores = { "msp": msp_score, "msp_is_ood": msp_score < msp_threshold, "energy": energy_score, "energy_is_ood": energy_score < energy_threshold, } _logger.info(ood_scores) return result, ood_scores def main(): # image examples for demo shuffled examples = glob("images/imagenet/*.jpg") + glob("images/ood/*.jpg") np.random.seed(42) np.random.shuffle(examples) # gradio interface interface = gr.Interface( fn=predict, inputs=Image(type="pil"), outputs=[ Label(num_top_classes=TOPK, label="Model prediction"), JSON(label="OOD scores"), ], examples=examples, examples_per_page=len(examples), allow_flagging="never", theme="default", title="OOD Detection 🧐", description="Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. This app demonstrates how these methods can be useful. They try to determine wether we can trust the predictions of a ResNet-50 model trained on ImageNet-1K. Enjoy the demo!", ) interface.launch( server_port=7860, ) interface.close() if __name__ == "__main__": logging.basicConfig(level=logging.WARN) gr.close_all() main()