File size: 780 Bytes
c4b2b37 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import matplotlib.pyplot as plt
from generic_utils import generate_visualization
def _plot_lrp(transform, image, class_index, method):
    """Plot *image* next to its relevance visualization and return the figure.

    Shared implementation behind :func:`do_lrp` and :func:`do_partial_lrp`;
    the two differ only in the ``method`` string forwarded to
    ``generate_visualization``.

    Args:
        transform: Callable applied to *image* to build the input passed to
            ``generate_visualization``.
        image: Image shown as-is in the left panel; must be accepted both by
            ``Axes.imshow`` and by *transform*.
        class_index: Forwarded unchanged to ``generate_visualization``.
        method: Visualization method name forwarded unchanged to
            ``generate_visualization``.

    Returns:
        A matplotlib ``Figure`` with two side-by-side panels (axes hidden):
        the original image on the left, the visualization on the right.
    """
    fig, (ax_image, ax_viz) = plt.subplots(1, 2)
    ax_image.imshow(image)
    ax_image.axis("off")
    # Only the visualization input goes through *transform*; the left panel
    # intentionally shows the untransformed image.
    transformed_image = transform(image)
    viz = generate_visualization(
        transformed_image, class_index=class_index, method=method
    )
    ax_viz.imshow(viz)
    ax_viz.axis("off")
    return fig


def do_lrp(transform, image, class_index=None):
    """Return a figure showing *image* beside its ``"full"`` visualization.

    Calls ``generate_visualization(transform(image), class_index=class_index,
    method="full")``. Presumably ``"full"`` means relevance propagated through
    the whole model — confirm against ``generic_utils``.

    Args:
        transform: Callable mapping *image* to the visualization input.
        image: Image displayed unmodified in the left panel.
        class_index: Forwarded to ``generate_visualization``; the meaning of
            ``None`` is defined there (presumably "use the predicted class" —
            TODO confirm).

    Returns:
        A matplotlib ``Figure`` with the image (left) and visualization (right).
    """
    return _plot_lrp(transform, image, class_index, method="full")


def do_partial_lrp(transform, image, class_index=None):
    """Return a figure showing *image* beside its ``"last_layer"`` visualization.

    Identical to :func:`do_lrp` except that ``generate_visualization`` is
    called with ``method="last_layer"``. Presumably this restricts the
    relevance computation to the last layer — confirm against
    ``generic_utils``.

    Args:
        transform: Callable mapping *image* to the visualization input.
        image: Image displayed unmodified in the left panel.
        class_index: Forwarded to ``generate_visualization``; the meaning of
            ``None`` is defined there.

    Returns:
        A matplotlib ``Figure`` with the image (left) and visualization (right).
    """
    return _plot_lrp(transform, image, class_index, method="last_layer")
|