import matplotlib.pyplot as plt

from visualization import generate_visualization, print_top_classes


def do_explain(transform, image, class_index=None):
    """Plot *image* beside its visualization and return the prediction.

    Builds a 1x2 matplotlib figure. The left axis shows ``image`` exactly as
    passed in. The right axis shows the result of ``generate_visualization``
    applied to ``transform(image)``. Both axes have their ticks and frame
    hidden.

    Args:
        transform: Callable applied to ``image``; its output is what
            ``generate_visualization`` and ``print_top_classes`` receive.
        image: Input displayed with ``imshow`` and fed to ``transform``
            (presumably a PIL image or array -- confirm with callers).
        class_index: Forwarded unchanged to ``generate_visualization`` as
            its ``class_index`` keyword; defaults to ``None``.

    Returns:
        tuple: ``(fig, predict)``, where ``fig`` is the newly created figure
        and ``predict`` is whatever ``print_top_classes`` returned.

    Raises:
        Any exception raised by ``transform``, ``generate_visualization`` or
        ``print_top_classes`` propagates to the caller unchanged. The
        partially built figure is closed before re-raising.
    """
    fig, axs = plt.subplots(1, 2)
    try:
        # Left panel: the original, untransformed input.
        axs[0].imshow(image)
        axs[0].axis("off")

        transformed_image = transform(image)
        viz = generate_visualization(
            transformed_image, class_index=class_index
        )
        predict = print_top_classes(transformed_image)

        # Right panel: the visualization computed from the transformed input.
        axs[1].imshow(viz)
        axs[1].axis("off")
    except Exception:
        # pyplot holds a reference to every figure made via plt.subplots
        # until it is explicitly closed. Without this, each failed call
        # would leave an orphaned open figure behind.
        plt.close(fig)
        raise
    return fig, predict