import numpy as np
import matplotlib.pyplot as plt


def simple_cmfd_decoder(busterNetModel, rgb):
    """Run a BusterNet CMFD model on a single image.

    The image is wrapped into a batch of one sample (a new leading axis is
    added), passed to ``busterNetModel.predict``, and the first element of
    the returned result is handed back.

    Args:
        busterNetModel: Object exposing ``predict(batch)`` whose result can
            be indexed with ``[0]`` (presumably a Keras-style model;
            confirm against the caller).
        rgb: Array-like image; it is passed through ``np.expand_dims`` so
            anything NumPy can convert to an array is accepted.

    Returns:
        The first element of ``busterNetModel.predict(batch)``, i.e. the
        prediction for the single input sample.
    """
    # The model is fed a batch, so give the lone image a leading batch axis.
    single_sample_batch = np.expand_dims(rgb, axis=0)
    # Drop the batch axis again by taking the only sample's prediction.
    return busterNetModel.predict(single_sample_batch)[0]


def visualize_result(rgb, gt, pred, figsize=(12, 4), title=None):
    """Show the input, ground truth and BusterNet prediction side by side.

    A new figure with a 1x3 grid of subplots is created; each image is drawn
    with ``plt.imshow`` and labelled with a fixed panel title.

    Args:
        rgb: Input image, shown in the first panel.
        gt: Ground-truth image, shown in the second panel.
        pred: BusterNet prediction, shown in the third panel.
        figsize: Figure size forwarded to ``plt.figure``.
        title: Optional figure-level title. When not ``None`` it is applied
            with ``fig.suptitle``; when ``None`` no figure title is drawn.

    Returns:
        The newly created matplotlib figure.
    """
    fig = plt.figure(figsize=figsize)
    panels = (
        (rgb, "input image"),
        (gt, "ground truth"),
        (pred, "busterNet pred"),
    )
    for position, (image, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image)
        plt.title(panel_title)
    # Previously `title` was accepted but silently ignored; honour it as a
    # figure-level title so the per-panel titles stay untouched.
    if title is not None:
        fig.suptitle(title)
    return fig