import logging
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import run_dff_on_image, GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from transformers import pipeline, ViTForImageClassification, ViTImageProcessor

from face_grab import FaceGrabber
from gradcam import GradCam

logging.basicConfig(level=logging.INFO)

# Model and processor are loaded once at import time and shared by every request.
model = ViTForImageClassification.from_pretrained("ongkn/emikes-classifier")
processor = ViTImageProcessor.from_pretrained("ongkn/emikes-classifier")

faceGrabber = FaceGrabber()
gradCam = GradCam()

# NOTE(review): not referenced anywhere else in the visible code; kept in case
# other code relies on this module-level name.
targetsForGradCam = [
    ClassifierOutputTarget(gradCam.category_name_to_index(model, "emi")),
    ClassifierOutputTarget(gradCam.category_name_to_index(model, "kes")),
]

# Layers handed to the DFF and Grad-CAM helpers respectively.
targetLayerDff = model.vit.layernorm
targetLayerGradCam = model.vit.encoder.layer[-2].output


def classify_image(input):
    """Classify the face found in ``input`` and build DFF / Grad-CAM images.

    Args:
        input: The image delivered by the Gradio ``"image"`` input component.
            It is converted with ``np.array`` before face detection. May be
            ``None`` (presumably when nothing was uploaded -- confirm against
            the Gradio version in use).

    Returns:
        A 5-tuple that lines up with the interface's five outputs:
        ``(label, score, face_image, dff_image, grad_cam_image)``.
        When there is no image or no face is found, the tuple is
        ``("No face detected", 0, input, None, None)``.
    """
    # Skip face detection entirely when there is nothing to look at.
    face = None if input is None else faceGrabber.grab_faces(np.array(input))
    if face is None:
        # The interface declares five outputs, so five values must be
        # returned here as well; the two visualisation slots are left empty.
        return "No face detected", 0, input, None, None

    face = Image.fromarray(face)
    faceResized = face.resize((224, 224))
    tensorResized = transforms.ToTensor()(faceResized)

    dffImage = run_dff_on_image(
        model=model,
        target_layer=targetLayerDff,
        classifier=model.classifier,
        img_pil=faceResized,
        img_tensor=tensorResized,
        reshape_transform=gradCam.reshape_transform_vit_huggingface,
        n_components=5,
        top_k=10,
    )

    result = gradCam.get_top_category(model, tensorResized)
    cls = result[0]["label"]

    # Grad-CAM is computed only for the predicted (top) class.
    clsIdx = gradCam.category_name_to_index(model, cls)
    clsTarget = ClassifierOutputTarget(clsIdx)
    gradCamImage = gradCam.run_grad_cam_on_image(
        model=model,
        target_layer=targetLayerGradCam,
        targets_for_gradcam=[clsTarget],
        input_tensor=tensorResized,
        input_image=faceResized,
        reshape_transform=gradCam.reshape_transform_vit_huggingface,
    )

    return cls, result[0]["score"], face, dffImage, gradCamImage


# NOTE(review): the description is an f-string, so {'emi', 'kes'} is evaluated
# as a tuple expression and rendered as ('emi', 'kes'). If literal braces were
# intended, double them ({{ }}) or drop the f prefix. Left unchanged here.
iface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs=["text", "number", "image", "image", "image"],
    title="Emikes Classifier",
    description=f"Takes in a (224, 224) image and outputs a class: {'emi', 'kes'}, along with a GradCam/DFF explanation. Face detection, cropping, and resizing are done internally. Uploaded images are not stored by us, but may be stored by HF. Refer to their [privacy policy](https://huggingface.co/privacy) for details.",
)
iface.launch()