import cv2
import torch
import torchvision
import torch.nn as nn
import numpy as np
import gradio as gr
from typing import List
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from pytorch_grad_cam import LayerCAM, DeepFeatureFactorization
from utils.save_load import load_model

# Architecture-style classes the classifier head was trained on.
labels = [
    "Achaemenid architecture",
    "American craftsman style",
    "American Foursquare architecture",
    "Ancient Egyptian architecture",
    "Art Deco architecture",
    "Art Nouveau architecture",
    "Baroque architecture",
    "Bauhaus architecture",
    "Beaux-Arts architecture",
    "Brutalism architecture",
    "Byzantine architecture",
    "Chicago school architecture",
    "Colonial architecture",
    "Deconstructivism",
    "Edwardian architecture",
    "Georgian architecture",
    "Gothic architecture",
    "Greek Revival architecture",
    "International style",
    "Islamic architecture",
    "Novelty architecture",
    "Palladian architecture",
    "Postmodern architecture",
    "Queen Anne architecture",
    "Romanesque architecture",
    "Russian Revival architecture",
    "Tudor Revival architecture",
]

# EfficientNetV2-L backbone with the classifier head swapped out for our classes.
model = torchvision.models.efficientnet_v2_l()
model.classifier = nn.Sequential(
    nn.Dropout(p=0.4, inplace=True),
    nn.Linear(1280, len(labels), bias=True),
)
load_model(model)

target_layer = model.features[-1]
classifier = model.classifier
# LayerCAM expects a list of target layers; DeepFeatureFactorization takes one.
cam = LayerCAM(model=model, target_layers=[target_layer], use_cuda=False)
dff = DeepFeatureFactorization(
    model=model, target_layer=target_layer, computation_on_concepts=classifier)


def show_factorization_on_image(img: np.ndarray,
                                explanations: np.ndarray,
                                colors: List[np.ndarray] = None,
                                image_weight: float = 0.5,
                                concept_labels: List = None) -> np.ndarray:
    """Blend the per-concept DFF explanation maps over the input image and,
    if concept_labels is given, append a color-coded legend below it."""
    n_components = explanations.shape[0]
    if colors is None:
        # Taken from https://github.com/edocollins/DFF/blob/master/utils.py
        _cmap = plt.cm.get_cmap('gist_rainbow')
        colors = [
            np.array(_cmap(i)) for i in np.arange(0, 1, 1.0 / n_components)
        ]
    # Color every pixel by its strongest concept only.
    concept_per_pixel = explanations.argmax(axis=0)
    masks = []
    for i in range(n_components):
        mask = np.zeros(shape=(img.shape[0], img.shape[1], 3))
        mask[:, :, :] = colors[i][:3]
        explanation = explanations[i]
        explanation[concept_per_pixel != i] = 0
        # Encode the explanation strength in the V channel of the concept color.
        mask = np.uint8(mask * 255)
        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2HSV)
        mask[:, :, 2] = np.uint8(255 * explanation)
        mask = cv2.cvtColor(mask, cv2.COLOR_HSV2RGB)
        mask = np.float32(mask) / 255
        masks.append(mask)

    mask = np.sum(np.float32(masks), axis=0)
    result = img * image_weight + mask * (1 - image_weight)
    result = np.uint8(result * 255)

    if concept_labels is not None:
        # Render the legend with matplotlib and stack it under the image.
        px = 1 / plt.rcParams['figure.dpi']  # pixel in inches
        fig = plt.figure(figsize=(result.shape[1] * px, result.shape[0] * px))
        plt.rcParams['legend.fontsize'] = 6 * result.shape[0] / 256
        lw = 5 * result.shape[0] / 256
        lines = [Line2D([0], [0], color=colors[i], lw=lw)
                 for i in range(n_components)]
        plt.legend(lines, concept_labels, fancybox=False, shadow=False,
                   frameon=False, loc="center")
        plt.tight_layout(pad=0, w_pad=0, h_pad=0)
        plt.axis('off')
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        plt.close(fig=fig)
        data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        data = cv2.resize(data, (result.shape[1], result.shape[0]))
        result = np.vstack((result, data))
    return result


def create_labels(concept_scores, top_k=2):
    """Create one legend entry per concept, listing its top_k highest-scoring
    architecture-style categories."""
    concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
    concept_labels_topk = []
    for concept_index in range(concept_categories.shape[0]):
        categories = concept_categories[concept_index, :]
        concept_labels = []
        for category in categories:
            score = concept_scores[concept_index, category]
            label = f"{labels[category].split(',')[0]}:{score * 100:.2f}%"
            concept_labels.append(label)
        concept_labels_topk.append("\n".join(concept_labels))
    return concept_labels_topk


def predict(rgb_img, top_k):
    top_k = int(top_k)  # Gradio sliders may deliver floats; slicing needs an int.
    # Normalization statistics computed on the training set.
    inp_01 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4937, 0.5060, 0.5030],
                             [0.2705, 0.2653, 0.2998]),
        transforms.Resize((224, 224)),
    ])(rgb_img)

    model.eval()
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(
            model(inp_01.unsqueeze(0))[0], dim=0)
        confidences = {labels[i]: float(prediction[i])
                       for i in range(len(labels))}

    # Factorize the target layer's activations into 5 concepts and project
    # each concept onto the classifier's output space.
    concepts, batch_explanations, concept_outputs = dff(inp_01.unsqueeze(0), 5)
    concept_outputs = torch.softmax(
        torch.from_numpy(concept_outputs), dim=-1).numpy()
    concept_label_strings = create_labels(concept_outputs, top_k=top_k)

    # Upsample the low-resolution explanation maps to the input image size.
    res = cv2.resize(
        np.transpose(batch_explanations[0], (1, 2, 0)),
        (rgb_img.size[0], rgb_img.size[1]))
    res = np.transpose(res, (2, 0, 1))

    visualization_01 = show_factorization_on_image(
        np.float32(rgb_img) / 255.0,
        res,
        image_weight=0.3,
        concept_labels=concept_label_strings)
    return confidences, visualization_01


gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(minimum=1, maximum=4, label="Number of top results", step=1),
    ],
    outputs=[gr.Label(num_top_classes=5), "image"],
    examples=[
        ["./assets/bauhaus.jpg", 1],
        ["./assets/frank_gehry.jpg", 2],
        ["./assets/pyramid.jpg", 3],
    ],
).launch()
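
# Headless smoke test (a sketch, not part of the app): exercises predict()
# without the Gradio UI, assuming one of the example assets above is present.
# Left commented out because Interface(...).launch() blocks; run it in a
# separate session instead:
#
#   img = Image.open("./assets/bauhaus.jpg").convert("RGB")
#   confidences, visualization = predict(img, top_k=2)
#   print(sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:3])
#   Image.fromarray(visualization).save("dff_visualization.png")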