import torch
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from resnet import ResNet18

# Load the trained CIFAR10 ResNet18 weights on CPU and switch to eval mode so that
# BatchNorm and Dropout layers behave deterministically at inference time.
model = ResNet18()
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
model.eval()

# Undoes the (mean 0.50, std 0.23) normalization when a tensor needs to be displayed again.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23]
)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def resize_image_pil(image, new_width, new_height):
    """Resize an image to fit inside new_width x new_height while preserving aspect ratio."""
    img = Image.fromarray(np.array(image)).convert("RGB")
    width, height = img.size
    scale = min(new_width / width, new_height / height)
    resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)
    # Crop (zero-padding if necessary) to exactly new_width x new_height.
    resized = resized.crop((0, 0, new_width, new_height))
    return resized


def inference(input_img, show_gradcam, num_gradcam, target_layer_number, opacity,
              show_misclassified, num_misclassified, num_top_classes):
    # Bring the uploaded image down to the 32x32 resolution the model was trained on.
    input_img = resize_image_pil(input_img, 32, 32)
    input_img = np.array(input_img)
    org_img = input_img  # unnormalized copy kept for the GradCAM overlay

    # Apply the normalization the model expects (the inverse of inv_normalize above).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.50, 0.50, 0.50), (0.23, 0.23, 0.23)),
    ])
    input_tensor = transform(input_img).unsqueeze(0)

    # Plain classification forward pass; gradients are not needed here.
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)

    # Top-k class probabilities for the confidence label.
    top_probs, top_labels = torch.topk(probs, k=min(int(num_top_classes), len(classes)))
    confidences = {classes[idx]: float(prob) for idx, prob in zip(top_labels[0], top_probs[0])}

    _, prediction = torch.max(outputs, 1)
    predicted_class = classes[prediction[0].item()]

    # GradCAM overlay; the library runs its own forward/backward pass internally.
    visualization = None
    if show_gradcam:
        target_layers = [model.layer2[int(target_layer_number)]]
        cam = GradCAM(model=model, target_layers=target_layers)
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        # image_weight is the weight given to the original image, so invert the opacity slider.
        visualization = show_cam_on_image(org_img / 255, grayscale_cam,
                                          use_rgb=True, image_weight=1 - opacity)

    if show_misclassified:
        # Listing misclassified images requires running the model over a labelled dataset;
        # this demo does not wire that up to an output component, so the checkbox is
        # currently a no-op (see the collect_misclassified sketch below).
        pass

    # Always return exactly three values, matching the three Gradio output components.
    return predicted_class, confidences, visualization
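# A minimal sketch of how the misclassified-images feature could be implemented, assuming
# the CIFAR10 test split can be downloaded locally with torchvision. The helper name
# collect_misclassified and the data root "./data" are illustrative, not part of the
# original app; the returned (image, caption) pairs could feed a gr.Gallery component.
def collect_misclassified(max_images=10):
    from torchvision import datasets

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.50, 0.50, 0.50), (0.23, 0.23, 0.23)),
    ])
    test_set = datasets.CIFAR10(root="./data", train=False, download=True,
                                transform=test_transform)
    loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

    misclassified = []
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images).argmax(dim=1)
            for img, label, pred in zip(images, labels, preds):
                if pred != label:
                    # Undo the normalization so the image displays with correct colors.
                    img = inv_normalize(img).permute(1, 2, 0).clamp(0, 1).numpy()
                    caption = f"true: {classes[label.item()]}, pred: {classes[pred.item()]}"
                    misclassified.append((img, caption))
                    if len(misclassified) >= max_images:
                        return misclassified
    return misclassified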
def launch():
    with gr.Blocks() as demo:
        gr.Markdown("# CIFAR10 ResNet18 Model with GradCAM")

        with gr.Row():
            input_image = gr.Image(width=256, height=256, label="Input Image")
            output_image = gr.Image(width=256, height=256, label="GradCAM Output")

        with gr.Row():
            prediction = gr.Textbox(label="Predicted Class")
            confidences = gr.Label(label="Top Class Confidences")

        with gr.Row():
            show_gradcam = gr.Checkbox(label="Show GradCAM")
            num_gradcam = gr.Slider(1, 5, value=1, step=1, label="Number of GradCAM images")
            target_layer = gr.Slider(-2, -1, value=-2, step=1, label="Target Layer")
            opacity = gr.Slider(0, 1, value=0.5, label="GradCAM Opacity")

        with gr.Row():
            show_misclassified = gr.Checkbox(label="Show Misclassified Images")
            num_misclassified = gr.Slider(1, 10, value=5, step=1, label="Number of Misclassified Images")
            num_top_classes = gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show")

        submit_btn = gr.Button("Submit")

        # Clickable example images that populate the input component when selected.
        gr.Examples(
            examples=[["cat.jpg"], ["dog.jpg"], ["bird.jpg"], ["plane.jpg"], ["car.jpg"],
                      ["deer.jpg"], ["frog.jpg"], ["horse.jpg"], ["ship.jpg"], ["truck.jpg"]],
            inputs=[input_image]
        )

        submit_btn.click(
            inference,
            inputs=[input_image, show_gradcam, num_gradcam, target_layer, opacity,
                    show_misclassified, num_misclassified, num_top_classes],
            outputs=[prediction, confidences, output_image]
        )

    demo.launch()


if __name__ == "__main__":
    launch()