File size: 4,318 Bytes
f64a3c9
eeea929
25d8a99
 
4f17ac7
eeea929
25d8a99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f17ac7
 
 
 
 
25d8a99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeea929
25d8a99
 
 
 
 
 
 
eeea929
25d8a99
eeea929
25d8a99
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import numpy as np
import torch
from PIL import Image
from models.custom_resnet import CustomResNet


# Run Inference
def run_inference(
    input_image,
    gradcam=False,
    gradcam_layer=3,
    gradcam_num=3,
    gradcam_opacity=0.5,
    misclassified_num=5,
    top_classes=10,
):
    """Return placeholder classification results for a CIFAR-10 image.

    No model is loaded yet: the label scores are hard-coded and every returned
    image is random noise. ``input_image``, ``gradcam_layer``,
    ``gradcam_opacity`` and ``top_classes`` are accepted so the signature lines
    up with the UI inputs, but they are currently ignored.

    Args:
      input_image: The image to be classified (currently unused).
      gradcam: Whether to return GradCAM images.
      gradcam_layer: The layer from which to generate GradCAM images
        (currently unused).
      gradcam_num: The number of GradCAM images to return when ``gradcam``
        is truthy.
      gradcam_opacity: The opacity of the GradCAM images (currently unused).
      misclassified_num: The number of misclassified images to return;
        zero or a negative value yields none.
      top_classes: The number of top classes to show (currently unused).

    Returns:
      A 3-tuple ``(labels, gradcam_images, misclassified_images)``:
        labels: dict mapping class name to score.
        gradcam_images: list of ``(32, 32, 3)`` float arrays; empty when
          ``gradcam`` is falsy.
        misclassified_images: list of ``(image, caption)`` tuples, where
          ``image`` is a ``(32, 32, 3)`` float array.
    """
    # TODO: replace the placeholders below with real inference: load the
    # CustomResNet weights (the draft used 'weight/epoch=2-step=294.ckpt'),
    # classify `input_image`, keep the `top_classes` best labels, render
    # GradCAM overlays for `gradcam_layer` at `gradcam_opacity`, and collect
    # genuinely misclassified samples.

    # Coerce the counts to int so that a float such as 4.0 does not make
    # range() raise TypeError. The GradCAM count is only read when requested.
    gradcam_count = int(gradcam_num) if gradcam else 0
    misclassified_count = max(int(misclassified_num), 0)

    # Placeholder for top classes. A separate local name is used so the
    # `top_classes` argument is not silently overwritten.
    labels = {"dog": 0.90, "cat": 0.10}

    # Placeholder GradCAM images: random noise at CIFAR-10 resolution.
    gradcam_images = [np.random.rand(32, 32, 3) for _ in range(gradcam_count)]

    # Placeholder misclassified images, each paired with a caption.
    misclassified_images = [
        (np.random.rand(32, 32, 3), 'caption')
        for _ in range(misclassified_count)
    ]

    return labels, gradcam_images, misclassified_images


# Gradio Interface
# Input components, declared in the same order as run_inference's parameters.
input_image = gr.Image(shape=(32, 32), label="Upload Image", info="Upload a CIFAR-10 image to be classified.")
gradcam = gr.Checkbox(label="View GradCAM images?", info="Whether to show GradCAM images.")
gradcam_layer = gr.Dropdown(["1", "2", "3"], value="2", label="GradCAM Layer", info="The layer from which to generate GradCAM images.")
gradcam_num = gr.Slider(label="Number of GradCAM images", minimum=1, maximum=10, step=1, info="The number of GradCAM images to show.")
gradcam_opacity = gr.Slider(label="GradCAM opacity", minimum=0.0, maximum=1.0, step=0.01, info="The opacity of the GradCAM images.")
misclassified_num = gr.Slider(label="Number of Misclassified images", minimum=0, maximum=10, step=1, info="The number of misclassified images to show.")
top_classes = gr.Slider(label="Number of top classes to show", minimum=1, maximum=10, step=1, info="The number of top classes to show.")

# Output components, in the order run_inference returns its results:
# label scores, GradCAM images, misclassified images.
output_label = gr.Label(num_top_classes=3, label="Top Classes")
# "contain" replaces the original "fit", which is not among the object_fit
# values documented for gr.Gallery.
output_gradcam_gallery = gr.Gallery(object_fit="contain", columns=4, height=280, label="GradCam Gallery")
output_misclassified_gallery = gr.Gallery(object_fit="contain", columns=4, height=280, label="Misclassified Images")

interface = gr.Interface(
    fn=run_inference,
    inputs=[
        input_image,
        gradcam,
        gradcam_layer,
        gradcam_num,
        gradcam_opacity,
        misclassified_num,
        top_classes,
    ],
    outputs=[output_label, output_gradcam_gallery, output_misclassified_gallery],
    # One value per input component, in the same order as `inputs`:
    # image path, gradcam, layer, gradcam count, opacity, misclassified count,
    # top classes.
    examples=[
        ['assets/0001.jpg', True, "3", 4, 0.5, 3, 3],
        ['assets/0002.jpg', False, "2", 1, 0.3, 1, 2],
        ['assets/0003.jpg', True, "2", 1, 0.3, 1, 2],
    ],
    title="Cifar-10 Inference with GradCAM",
    description="This is a CIFAR-10 image classifier using custom resnet. Upload a CIFAR-10 image and it will be classified into one of 10 categories.",
)
# Start the web UI without creating a public share link.
interface.launch(share=False)