ongkn's picture
Update app.py
04247ac
raw
history blame
No virus
2.97 kB
import gradio as gr
from transformers import pipeline, ViTForImageClassification, ViTImageProcessor
import numpy as np
from PIL import Image
import warnings
import logging
from pytorch_grad_cam import run_dff_on_image, GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch
from face_grab import FaceGrabber
from gradcam import GradCam
from torchvision import transforms
# Configure root logging at INFO level for the whole app.
logging.basicConfig(level=logging.INFO)

# Hub repository id shared by the model weights and the image processor config.
_MODEL_REPO = "ongkn/emikes-classifier"
model = ViTForImageClassification.from_pretrained(_MODEL_REPO)
processor = ViTImageProcessor.from_pretrained(_MODEL_REPO)

# Project-local helpers (see face_grab.py / gradcam.py).
faceGrabber = FaceGrabber()
gradCam = GradCam()

# Grad-CAM output targets for the "emi" and "kes" labels, in that order.
targetsForGradCam = [
    ClassifierOutputTarget(gradCam.category_name_to_index(model, label))
    for label in ("emi", "kes")
]

# Layers handed to the DFF and Grad-CAM routines respectively.
targetLayerDff = model.vit.layernorm
targetLayerGradCam = model.vit.encoder.layer[-2].output
def classify_image(input):
    """Classify the face in an image and build DFF / Grad-CAM explanations.

    Args:
        input: Image delivered by the Gradio ``"image"`` input. It is
            converted with ``np.array`` before being passed to the face
            grabber.

    Returns:
        A 5-tuple matching the five outputs declared on the interface:
        ``(label, score, face, dffImage, gradCamImage)``. When the face
        grabber returns ``None`` the tuple is
        ``("No face detected", 0, input, None, None)``.
    """
    face = faceGrabber.grab_faces(np.array(input))
    if face is None:
        # The interface declares five outputs; returning only three makes
        # Gradio raise on the mismatch instead of showing this message.
        # The two explanation image slots are therefore padded with None.
        return "No face detected", 0, input, None, None

    face = Image.fromarray(face)
    # Both explanation routines below are fed the same 224x224 crop, once as
    # a PIL image and once as a tensor.
    faceResized = face.resize((224, 224))
    tensorResized = transforms.ToTensor()(faceResized)

    dffImage = run_dff_on_image(
        model=model,
        target_layer=targetLayerDff,
        classifier=model.classifier,
        img_pil=faceResized,
        img_tensor=tensorResized,
        reshape_transform=gradCam.reshape_transform_vit_huggingface,
        n_components=5,
        top_k=10,
    )

    result = gradCam.get_top_category(model, tensorResized)
    topLabel = result[0]["label"]
    topScore = result[0]["score"]

    # Grad-CAM is computed only for the predicted (top) class.
    clsIdx = gradCam.category_name_to_index(model, topLabel)
    clsTarget = ClassifierOutputTarget(clsIdx)
    gradCamImage = gradCam.run_grad_cam_on_image(
        model=model,
        target_layer=targetLayerGradCam,
        targets_for_gradcam=[clsTarget],
        input_tensor=tensorResized,
        input_image=faceResized,
        reshape_transform=gradCam.reshape_transform_vit_huggingface,
    )

    return topLabel, topScore, face, dffImage, gradCamImage
# Gradio UI: one image in; label, score, cropped face, DFF image and Grad-CAM
# image out (five outputs, matching classify_image's return tuple).
iface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs=["text", "number", "image", "image", "image"],
    title="Emikes Classifier",
    # Deliberately a plain string, not an f-string: inside an f-string
    # {'emi', 'kes'} is evaluated as a set literal, whose rendered order
    # varies between processes because of string hash randomization.
    description=(
        "Takes in a (224, 224) image and outputs a class: {'emi', 'kes'}, "
        "along with a GradCam/DFF explanation. Face detection, cropping, and "
        "resizing are done internally. Uploaded images are not stored by us, "
        "but may be stored by HF. Refer to their "
        "[privacy policy](https://huggingface.co/privacy) for details."
    ),
)
iface.launch()