NeuralVista / app.py
BhumikaMak's picture
Fix: incorrect path
fa45d53
raw
history blame
3.37 kB
import numpy as np
import cv2
import os
from PIL import Image
import torchvision.transforms as transforms
import gradio as gr
from yolov5 import xai_yolov5
from yolov8 import xai_yolov8s
def process_image(image, yolo_versions=None):
    """Run each requested YOLO explainability pipeline on ``image``.

    Args:
        image: Input image (a PIL image from the Gradio ``Image`` component, or
            anything ``np.array`` accepts). ``None`` (nothing uploaded) yields
            an empty result list.
        yolo_versions: Model names to run, in order; defaults to ``["yolov5"]``.
            ``"yolov5"`` and ``"yolov8s"`` are dispatched to their ``xai_*``
            functions; any other name yields the resized input image with a
            "not yet implemented" caption.

    Returns:
        list: One entry per requested model, in request order, holding whatever
        the corresponding ``xai_*`` function returns, or an ``(image, caption)``
        tuple for unknown model names.
    """
    # Avoid a shared mutable default list; same default as before.
    if yolo_versions is None:
        yolo_versions = ["yolov5"]
    # Run clicked without an image: np.array(None) would produce a 0-d object
    # array that cv2.resize cannot handle, so return an empty gallery instead.
    if image is None:
        return []
    image = cv2.resize(np.array(image), (640, 640))
    result_images = []
    for yolo_version in yolo_versions:
        if yolo_version == "yolov5":
            result_images.append(xai_yolov5(image))
        elif yolo_version == "yolov8s":
            result_images.append(xai_yolov8s(image))
        else:
            result_images.append((Image.fromarray(image), f"{yolo_version} not yet implemented."))
    return result_images
# Sample images offered in the dropdown, resolved relative to the current
# working directory. Each path component is passed separately: a component
# starting with "/" makes os.path.join discard everything before it, which
# previously turned these into absolute "/data/xai/..." paths.
sample_images = {
    "Sample 1": os.path.join(os.getcwd(), "data", "xai", "sample1.jpeg"),
    "Sample 2": os.path.join(os.getcwd(), "data", "xai", "sample2.jpg"),
}

# NOTE(review): the triple-quoted string below is an earlier gr.Interface-based
# UI kept as a bare string literal; it is never executed and has been
# superseded by the gr.Blocks interface further down. If revived: main_logic
# calls `load_sample`, which is not defined in this file (the helper is
# `load_sample_image`); the Interface uses fn=process_image rather than
# main_logic, yet passes three inputs to a two-parameter function.
"""
interface = gr.Interface(
fn=process_image,
inputs=[
gr.Image(type="pil", label="Upload an Image"),
gr.CheckboxGroup(
choices=["yolov5", "yolov8s"],
value=["yolov5"], # Set the default value (YOLOv5 checked by default)
label="Select Model(s)",
),
gr.Dropdown(
choices=list(sample_images.keys()),
label="Select a Sample Image",
type="value",
interactive=True,
),
],
outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
title="Visualising the key image features that drive decisions with our explainable AI tool.",
description="XAI: Upload an image or select a sample to visualize object detection of your models.",
)
def main_logic(uploaded_image, selected_models, sample_selection):
# If the user selects a sample image, use that instead of the uploaded one
if sample_selection:
image = load_sample(sample_selection)
else:
image = uploaded_image
# Call the processing function
return process_image(image, selected_models)
interface.launch()
"""
def load_sample_image(sample_name):
    """Load one of the bundled sample images by its dropdown label.

    Args:
        sample_name: A key of ``sample_images`` (e.g. ``"Sample 1"``); ``None``
            or an unknown label is tolerated.

    Returns:
        PIL.Image.Image | None: A fully loaded in-memory copy of the image, or
        ``None`` if the label is unknown or the file cannot be opened/decoded.
    """
    path = sample_images.get(sample_name)
    if path is None:
        return None
    try:
        # Image.open is lazy and keeps the file handle open; copying inside the
        # context manager reads the pixel data so the handle can be closed.
        with Image.open(path) as img:
            return img.copy()
    except OSError as e:  # missing file, permission denied, unrecognised format
        print(f"Error loading image: {e}")
        return None
# Gradio interface
def _run_on_selected_image(uploaded, models, sample_name):
    """Run the selected models on the uploaded image, or on the chosen sample.

    An uploaded image takes precedence, matching the previous Run behaviour.
    The sample picked in the dropdown is used only when nothing is uploaded.
    If neither is available, return an empty gallery rather than failing
    inside the model code.
    """
    image = uploaded if uploaded is not None else load_sample_image(sample_name)
    if image is None:
        return []
    return process_image(image, models)


with gr.Blocks() as interface:
    gr.Markdown("# Visualizing Key Features with Explainable AI")
    gr.Markdown("Upload an image or select a sample image to visualize object detection.")
    with gr.Row():
        uploaded_image = gr.Image(type="pil", label="Upload an Image")
        sample_selection = gr.Dropdown(
            choices=list(sample_images.keys()),
            label="Select a Sample Image",
            type="value",
        )
        sample_display = gr.Image(label="Sample Image Preview", value=None)
    # Preview the sample as soon as it is picked.
    sample_selection.change(fn=load_sample_image, inputs=sample_selection, outputs=sample_display)
    # Choices must match the model names process_image dispatches on, and the
    # default value must be one of the choices ("yolov3" was never handled).
    selected_models = gr.CheckboxGroup(
        choices=["yolov5", "yolov8s"],
        value=["yolov5"],  # Default model
        label="Select Model(s)",
    )
    result_gallery = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500)
    gr.Button("Run").click(
        fn=_run_on_selected_image,
        inputs=[uploaded_image, selected_models, sample_selection],
        outputs=result_gallery,
    )
interface.launch()