Spaces:
Runtime error
Runtime error
File size: 4,318 Bytes
f64a3c9 eeea929 25d8a99 4f17ac7 eeea929 25d8a99 4f17ac7 25d8a99 eeea929 25d8a99 eeea929 25d8a99 eeea929 25d8a99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
import numpy as np
import torch
from PIL import Image
from models.custom_resnet import CustomResNet
# Run Inference
def run_inference(
    input_image,
    gradcam=False,
    gradcam_layer=3,
    gradcam_num=3,
    gradcam_opacity=0.5,
    misclassified_num=5,
    top_classes=10,
):
    """Return placeholder classification results for a CIFAR-10 image.

    The trained model is not wired in yet: this function returns fixed class
    confidences and random 32x32x3 arrays standing in for GradCAM overlays
    and misclassified samples.

    Args:
        input_image: The image to be classified. Currently unused.
        gradcam: Whether to return GradCAM images.
        gradcam_layer: The layer from which to generate GradCAM images.
            Currently unused.
        gradcam_num: The number of GradCAM images to return when ``gradcam``
            is true. Float values (e.g. ``3.0``) are truncated to int.
        gradcam_opacity: The opacity of the GradCAM images. Currently unused.
        misclassified_num: The number of misclassified images to return.
            Float values are truncated to int; values <= 0 give an empty list.
        top_classes: The number of top classes to show. Currently unused.

    Returns:
        A 3-tuple ``(confidences, gradcam_images, misclassified_images)``:
        a dict mapping class name to confidence, a list of ``(32, 32, 3)``
        arrays (empty when ``gradcam`` is false), and a list of
        ``(array, caption)`` pairs.
    """
    # TODO: replace the placeholders below with real inference: load
    # CustomResNet from the 'weight/epoch=2-step=294.ckpt' checkpoint, predict
    # on input_image, keep the `top_classes` best classes, build GradCAM
    # overlays for `gradcam_layer` at `gradcam_opacity`, and collect the
    # misclassified samples.

    # Placeholder confidences. Kept in its own local so the `top_classes`
    # argument is not overwritten.
    confidences = {"dog": 0.90, "cat": 0.10}

    # UI sliders may hand over numeric counts as floats; range() needs ints.
    # gradcam_num is only read when GradCAM output was requested.
    num_gradcam = int(gradcam_num) if gradcam else 0
    num_misclassified = max(int(misclassified_num), 0)

    # Placeholder GradCAM images: random RGB arrays.
    gradcam_images = [np.random.rand(32, 32, 3) for _ in range(num_gradcam)]

    # Placeholder misclassified images: (random RGB array, caption) pairs.
    misclassified_images = [
        (np.random.rand(32, 32, 3), 'caption') for _ in range(num_misclassified)
    ]

    return confidences, gradcam_images, misclassified_images
# Gradio Interface
#
# Input components. Their order in `inputs=[...]` below must match the
# positional parameters of run_inference.
# NOTE(review): `shape=` and `info=` on gr.Image are not accepted by every
# Gradio release -- confirm against the Gradio version pinned for this app.
input_image = gr.Image(
    shape=(32, 32),
    label="Upload Image",
    info="Upload a CIFAR-10 image to be classified.",
)
gradcam = gr.Checkbox(
    label="View GradCAM images?",
    info="Whether to show GradCAM images.",
)
gradcam_layer = gr.Dropdown(
    ["1", "2", "3"],
    value="2",
    label="GradCAM Layer",
    info="The layer from which to generate GradCAM images.",
)
gradcam_num = gr.Slider(
    label="Number of GradCAM images",
    minimum=1,
    maximum=10,
    step=1,
    info="The number of GradCAM images to show.",
)
gradcam_opacity = gr.Slider(
    label="GradCAM opacity",
    minimum=0.0,
    maximum=1.0,
    step=0.01,
    info="The opacity of the GradCAM images.",
)
misclassified_num = gr.Slider(
    label="Number of Misclassified images",
    minimum=0,
    maximum=10,
    step=1,
    info="The number of misclassified images to show.",
)
top_classes = gr.Slider(
    label="Number of top classes to show",
    minimum=1,
    maximum=10,
    step=1,
    info="The number of top classes to show.",
)

# Output components, in the order run_inference returns its results.
output_label = gr.Label(num_top_classes=3, label="Top Classes")
# NOTE(review): "fit" is not a CSS object-fit keyword ("contain" may be what
# was intended) -- confirm against the Gallery documentation.
output_gradcam_gallery = gr.Gallery(
    object_fit="fit",
    columns=4,
    height=280,
    label="GradCam Galery",
)
output_misclassified_gallery = gr.Gallery(
    object_fit="fit",
    columns=4,
    height=280,
    label="Misclassified Images",
)

interface = gr.Interface(
    fn=run_inference,
    inputs=[
        input_image,
        gradcam,
        gradcam_layer,
        gradcam_num,
        gradcam_opacity,
        misclassified_num,
        top_classes,
    ],
    outputs=[
        output_label,
        output_gradcam_gallery,
        output_misclassified_gallery,
    ],
    # One value per input component, in the same order as `inputs`:
    # image, gradcam, gradcam_layer, gradcam_num, gradcam_opacity,
    # misclassified_num, top_classes.
    examples=[
        ['assets/0001.jpg', True, "3", 4, 0.5, 3, 3],
        ['assets/0002.jpg', False, "2", 1, 0.3, 1, 2],
        ['assets/0003.jpg', True, "2", 1, 0.3, 1, 2],
    ],
    title="Cifar-10 Inference with GradCAM",
    description="This is a CIFAR-10 image classifier using custom resnet. Upload a CIFAR-10 image and it will be classified into one of 10 categories.",
)
interface.launch(share=False)