# TSAIGradcam / app.py
# Last commit 7580432 by ibrim: "Update app.py" (verified)
from gradio_utils import *
def process_images_gradcam(show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity):
    """Render Grad-CAM overlays for misclassified test images.

    Args:
        show_gradcam: Checkbox value. When falsy nothing is computed and
            ``None`` is returned, which clears the output image.
        gradcam_count: Number of misclassified samples to visualise
            (coerced to ``int``; Gradio ``Number`` values may arrive as floats).
        gradcam_layer: Which ``modelfin.model.layerN`` to target, "1".."4".
            An int (e.g. a numeric radio default) is accepted too.
        gradcam_opacity: Overlay transparency forwarded to
            ``display_gradcam_output``.

    Returns:
        Whatever ``display_gradcam_output`` returns, or ``None`` when
        ``show_gradcam`` is unchecked.

    Raises:
        gr.Error: if ``gradcam_layer`` is not one of 1..4 or the count is < 1.
    """
    if not show_gradcam:
        # Previously fell through and hit an unbound local; clear output instead.
        return None

    # The radio's default may be an int, so normalise before comparing.
    layer_id = str(gradcam_layer)
    if layer_id not in {"1", "2", "3", "4"}:
        raise gr.Error(f"Unknown GradCAM layer: {gradcam_layer!r}")
    if not gradcam_count or int(gradcam_count) < 1:
        raise gr.Error("Please request at least one GradCAM image.")

    # Inverts Normalize(mean=(0.4914, 0.4822, 0.4465), std~=(0.247, 0.243, 0.261)):
    # mean' = -mean/std, std' = 1/std, so images are shown in the original colour range.
    inv_normalize = transforms.Normalize(
        mean=[-1.9899, -1.9844, -1.7111],
        std=[4.0486, 4.1152, 3.8314])
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    misclassified_data = get_misclassified_data(modelfin, "cpu", test_loader)
    # Target the last sub-module of the chosen stage, as the original branches did.
    target_layer = getattr(modelfin.model, f"layer{layer_id}")[-1]
    return display_gradcam_output(
        misclassified_data,
        classes,
        inv_normalize,
        modelfin,
        target_layers=[target_layer],
        targets=None,
        number_of_samples=int(gradcam_count),
        transparency=gradcam_opacity,
    )
def process_images_misclass(show_misclassify, misclassify_count):
    """Render a grid of misclassified test images.

    Args:
        show_misclassify: Checkbox value. When falsy nothing is computed and
            ``None`` is returned, which clears the output image.
        misclassify_count: Number of samples to show (coerced to ``int``;
            Gradio ``Number`` values may arrive as floats).

    Returns:
        Whatever ``display_cifar_misclassified_data`` returns, or ``None``
        when ``show_misclassify`` is unchecked.

    Raises:
        gr.Error: if fewer than one image is requested.
    """
    if not show_misclassify:
        # Previously fell through and hit an unbound local; clear output instead.
        return None
    if not misclassify_count or int(misclassify_count) < 1:
        raise gr.Error("Please request at least one misclassified image.")

    # These were only defined as locals of process_images_gradcam, so this
    # function could not see them; define the same values here.
    inv_normalize = transforms.Normalize(
        mean=[-1.9899, -1.9844, -1.7111],
        std=[4.0486, 4.1152, 3.8314])
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    misclassified_data = get_misclassified_data(modelfin, "cpu", test_loader)
    return display_cifar_misclassified_data(
        misclassified_data,
        classes,
        inv_normalize,
        number_of_samples=int(misclassify_count),
    )
def predict_classes(upload_image, top_classes):
    """Classify an uploaded image with ``modelfin`` and list its top-k classes.

    Args:
        upload_image: PIL image from the ``gr.Image(type='pil')`` input, or
            ``None`` when nothing has been uploaded.
        top_classes: Requested number of classes. Gradio ``Number`` values may
            be floats or ``None``; the value is coerced to ``int`` and clamped
            to ``1..len(classes)`` because ``torch.topk`` rejects anything else.

    Returns:
        One ``"<class>: <probability>%"`` line per prediction, most likely
        first, each terminated by a newline.

    Raises:
        gr.Error: if no image was uploaded.
    """
    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    if upload_image is None:
        raise gr.Error("Please upload an image first.")

    # NOTE(review): this std (0.2023, 0.1994, 0.2010) differs from the std
    # (~0.247, 0.243, 0.261) implied by inv_normalize used for the test-set
    # views; confirm which one the model was trained with.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2023, 0.1994, 0.2010])])

    # RGBA / greyscale uploads would not match the 3-channel Normalize above.
    batch = transform(upload_image.convert("RGB")).unsqueeze(0)
    batch = batch.to(next(modelfin.parameters()).device)

    k = max(1, min(int(top_classes or 1), len(classes)))

    modelfin.eval()
    with torch.no_grad():
        logits = modelfin(batch)
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    top_prob, top_catid = torch.topk(probabilities, k)

    return "".join(
        f"{classes[idx]}: {prob * 100:.2f}%\n"
        for prob, idx in zip(top_prob[0].tolist(), top_catid[0].tolist())
    )
# Gradio UI: Grad-CAM viewer, misclassified-image viewer and single-image classifier.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            show_gradcam = gr.Checkbox(label="Show GradCAM Images?")
            gradcam_count = gr.Number(label="How many GradCAM images?", value=1, precision=0)
            # The default must be one of the (string) choices; the int 4 never matched "4".
            gradcam_layer = gr.Radio(choices=["1", "2", "3", "4"], label="Choose a layer", value="4")
            gradcam_opacity = gr.Slider(minimum=0, maximum=1, label="Opacity of overlay", value=0.5)
            submit_button = gr.Button("GradCam")
            outputs = gr.Image(label="Output")

            show_misclassify = gr.Checkbox(label="Show misclassified images?")
            # precision=0 makes Gradio deliver an int rather than a float.
            misclassify_count = gr.Number(label="How many misclassified images?", value=1, precision=0)
            submit_button_misclass = gr.Button("Misclassified")
            outputs_misclass = gr.Image(label="Output")

            upload_image = gr.Image(label="Upload your image", interactive=True, type='pil')
            # torch.topk needs an int k in 1..10 for the 10 CIFAR-10 classes.
            top_classes = gr.Number(
                label="How many top classes would you like to see?",
                value=5, minimum=1, maximum=10, precision=0,
            )
            upload_btn = gr.Button("Classify your image")
            show_classes = gr.Textbox(label="Your top classes", interactive=False)

    submit_button_misclass.click(
        process_images_misclass,
        inputs=[show_misclassify, misclassify_count],
        outputs=outputs_misclass,
    )
    submit_button.click(
        process_images_gradcam,
        inputs=[show_gradcam, gradcam_count, gradcam_layer, gradcam_opacity],
        outputs=outputs,
    )
    upload_btn.click(
        predict_classes,
        inputs=[upload_image, top_classes],
        outputs=show_classes,
    )

demo.launch()