"""Save side-by-side plots of ``app.inference`` outputs for every example image.

For each image path in ``app.examples`` the image is passed to ``inference``
and its outputs are drawn into two figures, one column per example. The
figures are saved as ``CAM.png`` and ``ROLLOUT.png`` in the current working
directory.
"""
import matplotlib.pyplot as plt
from app import inference, examples
from PIL import Image

plt.rcParams["figure.figsize"] = (11, 2)

# One figure per entry; the entry is also used as the output file stem.
titles = ["CAM", "ROLLOUT"]

if not examples:
    # plt.subplots(1, 0) would otherwise fail with an opaque GridSpec error.
    raise ValueError("app.examples is empty; nothing to plot")

# squeeze=False keeps the axes as a 2-D array even when there is a single
# example, so axes[0, i] works for any len(examples) >= 1.
plots = [plt.subplots(1, len(examples), squeeze=False) for _ in titles]

for i, image_path in enumerate(examples):
    # Close the underlying file handle once inference and plotting are done.
    with Image.open(image_path) as image:
        # The result is read as alternating (title, image) pairs:
        # result[2*j] is the subplot title and result[2*j + 1] the image
        # drawn into figure j -- TODO confirm against app.inference.
        result = inference(image)
        for j, (_, axes) in enumerate(plots):
            ax = axes[0, i]
            ax.imshow(result[2 * j + 1])
            ax.set_title(result[2 * j])
            ax.set_axis_off()

for name, (fig, _) in zip(titles, plots):
    fig.savefig(f"{name}.png")
    # Release the figure's memory now that it has been written to disk.
    plt.close(fig)