|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def simple_cmfd_decoder(busterNetModel, rgb):
    """Run the model on one image and return that image's prediction.

    A leading batch axis of length 1 is added to ``rgb`` with
    ``np.expand_dims``. The resulting batch is passed to
    ``busterNetModel.predict``, and entry 0 of the returned batch is
    handed back.

    Args:
        busterNetModel: Any object exposing ``predict(batch)``. Presumably
            a Keras BusterNet model; confirm against callers.
        rgb: A single input sample (array-like). It must NOT already carry
            a batch axis, since one is prepended here.

    Returns:
        The first element of ``busterNetModel.predict``'s output, i.e. the
        prediction for ``rgb`` with the batch axis stripped.
    """
    return busterNetModel.predict(np.expand_dims(rgb, axis=0))[0]
|
|
|
|
|
def visualize_result(rgb, gt, pred, figsize=(12, 4), title=None):
    """Visualize raw input, ground truth, and BusterNet result side by side.

    Creates a new figure with three panels in one row: the input image,
    the ground-truth mask, and the BusterNet prediction. Each panel is
    drawn with ``imshow`` and labelled with its own axes title.

    Args:
        rgb: Input image, in any form accepted by matplotlib's ``imshow``.
        gt: Ground-truth mask, in any form accepted by ``imshow``.
        pred: BusterNet prediction (e.g. the return value of
            ``simple_cmfd_decoder``), in any form accepted by ``imshow``.
        figsize: Figure size in inches, forwarded to ``plt.figure``.
        title: Optional figure-level title. When not ``None``, it is drawn
            above all three panels via ``Figure.suptitle``.

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    fig = plt.figure(figsize=figsize)

    panels = (
        (rgb, "input image"),
        (gt, "ground truth"),
        (pred, "busterNet pred"),
    )
    # Draw on explicit axes of *this* figure rather than relying on
    # pyplot's implicit "current figure/axes" state.
    for position, (image, label) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position)
        ax.imshow(image)
        ax.set_title(label)

    # The ``title`` argument was previously accepted but never applied.
    if title is not None:
        fig.suptitle(title)

    return fig
|
|