ood-detection / app.py
edadaltocg's picture
implement app
9a960ac
raw
history blame
3.61 kB
"""
Gradio demo of image classification with OOD detection.
Alongside the top predictions, the app reports MSP and energy OOD scores and flags inputs that are probably OOD, signalling that the prediction should not be trusted.
"""
import os
import pickle
import json
from glob import glob
import gradio as gr
from gradio.components import Image, Label, JSON
import numpy as np
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
import logging
# Module-level logger; predict() reports intermediate results through it.
_logger = logging.getLogger(__name__)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Number of highest-probability classes shown for each prediction.
TOPK: int = 3
# Load the pretrained timm ResNet-50, place it on `device` and switch it to
# inference mode (nn.Module.to() and .eval() both return the module itself).
print("Loading model...")
model = timm.create_model("resnet50", pretrained=True).to(device).eval()
# Class labels keyed by integer class index (JSON object keys are strings, so
# they are converted to int to match the indices returned by torch.topk).
# The context manager closes the file handle, which the bare open().read()
# never did.
with open("ilsvrc2012.json", encoding="utf-8") as label_file:
    idx2label = {int(k): v for k, v in json.load(label_file).items()}
# Dumping the whole label map is debug output; keep it off stdout by default.
_logger.debug(idx2label)
# Evaluation-time preprocessing derived from the model's pretrained data config.
config = {**resolve_data_config({}, model=model), "is_training": False}
transform = create_transform(**config)
# Listing graph-node names is only useful when choosing which nodes to extract.
# get_graph_node_names() symbolically traces the model, so skip that work
# unless debug logging is enabled (the original printed it unconditionally).
if _logger.isEnabledFor(logging.DEBUG):
    _logger.debug(get_graph_node_names(model)[0])
# Nodes whose outputs the feature extractor returns: the flattened global-pool
# output (penultimate features) and the final "fc" layer output (logits).
penultimate_features_key = "global_pool.flatten"
logits_key = "fc"
features_names = [penultimate_features_key, logits_key]
# Wraps `model` so a forward pass returns a dict keyed by the node names above.
feature_extractor = create_feature_extractor(model, features_names)
# OOD detector thresholds: predict() flags an input as OOD when its score is
# strictly below the matching threshold.
# NOTE(review): the provenance of these values is not shown here. energy()
# returns an unnormalized logsumexp of the logits (not bounded to [0, 1] like
# MSP), so confirm 0.3781 was calibrated on that same score.
msp_threshold, energy_threshold = 0.3796, 0.3781
## OOD scoring functions
def mahalanobis_penult(features):
    """Return the negated L2 norm of one penultimate-layer feature vector.

    NOTE(review): despite its name, this is a feature-norm score, not a
    Mahalanobis distance -- no class means or covariance are involved.
    It is not used by predict().

    Args:
        features: 2-D tensor shaped (1, D). TODO confirm callers pass exactly
            one sample: ``.item()`` raises for more than one element.

    Returns:
        ``-||features||_2`` as a Python float.
    """
    # The original passed ``keepdims=True`` to torch.norm, whose keyword is
    # ``keepdim``, so every call raised TypeError. The following min over a
    # size-1 dim was a no-op, so the plain per-row norm is equivalent.
    norms = torch.linalg.vector_norm(features, dim=1)
    return -norms.item()
def msp(logits):
    """Maximum softmax probability (MSP) score for a single sample.

    Softmax is taken over dim 1 of ``logits`` and the largest probability is
    returned as a Python float (``.item()`` needs exactly one value).
    """
    probabilities = torch.softmax(logits, dim=1)
    top_probability, _ = probabilities.max(dim=-1)
    return top_probability.item()
def energy(logits, temperature=1.0):
    """Energy-based OOD score ``T * logsumexp(logits / T)`` for one sample.

    predict() treats lower scores as more likely OOD. With the default
    ``temperature=1.0`` this is exactly ``logsumexp(logits, dim=1)``, the
    original behavior.

    Args:
        logits: tensor of class scores; logsumexp is reduced over dim 1 and the
            result must contain a single element (``.item()``).
        temperature: positive temperature ``T`` used to scale the logits.

    Returns:
        The score as a Python float.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    scaled = logits / temperature
    return (temperature * torch.logsumexp(scaled, dim=1)).item()
def predict(image):
    """Classify ``image`` and report OOD scores for it.

    Args:
        image: a PIL image (the Gradio input is declared with ``type="pil"``).

    Returns:
        ``(result, ood_scores)``. ``result`` maps the ``TOPK`` most probable
        class labels to their softmax probabilities. ``ood_scores`` holds the
        MSP and energy scores, plus a boolean per score that is True when the
        score is below its OOD threshold.
    """
    # Build a batch of one and move it to the model's device. The model was
    # moved with model.to(device); CPU inputs would fail when that is CUDA.
    inputs = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = feature_extractor(inputs)
    logits = features[logits_key]

    # Top-TOPK predictions.
    probabilities = torch.softmax(logits, dim=-1)
    top_probs, top_idxs = torch.topk(probabilities, TOPK)
    _logger.info(top_probs)
    _logger.info(top_idxs)
    # Take the single batch row instead of squeeze(): with TOPK == 1,
    # squeeze() collapses to a 0-d tensor that zip() cannot iterate.
    result = {
        idx2label[idx]: prob
        for idx, prob in zip(top_idxs[0].tolist(), top_probs[0].tolist())
    }

    # OOD scores: lower than the threshold means "probably OOD".
    msp_score = msp(logits)
    energy_score = energy(logits)
    ood_scores = {
        "msp": msp_score,
        "msp_is_ood": msp_score < msp_threshold,
        "energy": energy_score,
        "energy_is_ood": energy_score < energy_threshold,
    }
    _logger.info(ood_scores)
    return result, ood_scores
def main():
    """Build the Gradio interface around predict() and serve it on port 7860."""
    # glob() order depends on the filesystem, so sort before the seeded
    # shuffle; otherwise the "reproducible" example order differs by machine.
    examples = sorted(glob("images/imagenet/*.jpg")) + sorted(glob("images/ood/*.jpg"))
    # A private RandomState(42) performs the same legacy shuffle as
    # np.random.seed(42) + np.random.shuffle, without reseeding NumPy's
    # global RNG as a side effect.
    np.random.RandomState(42).shuffle(examples)

    interface = gr.Interface(
        fn=predict,
        inputs=Image(type="pil"),
        outputs=[
            Label(num_top_classes=TOPK, label="Model prediction"),
            JSON(label="OOD scores"),
        ],
        # Show all examples on one page; fall back gracefully when none exist
        # (an empty list would otherwise give examples_per_page=0).
        examples=examples or None,
        examples_per_page=max(len(examples), 1),
        allow_flagging="never",
        theme="default",
        title="OOD Detection 🧐",
        description="Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. This app demonstrates how these methods can be useful in determining whether the inputs of a ResNet-50 model trained on ImageNet-1K can be trusted by the model. Enjoy the demo!",
    )
    try:
        interface.launch(
            server_port=7860,
        )
    finally:
        # Release the server even if launch() raises (e.g. port already in use).
        interface.close()
if __name__ == "__main__":
    # logging.WARNING is the documented name (WARN is a legacy alias with the
    # same value). At this level the INFO logs emitted by predict() stay hidden.
    logging.basicConfig(level=logging.WARNING)
    # Shut down any Gradio interfaces previously launched in this process
    # before starting a new one.
    gr.close_all()
    main()