# contrails_demo / contrails_predictions.py
# pleasantville's picture
# Update contrails_predictions.py
# 7cafa82 verified
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image
import io
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
def fig2img(fig, dpi=300):
    """Render a Matplotlib figure into an in-memory buffer and open it with PIL.

    Args:
        fig: Object exposing ``savefig(file, dpi=...)`` (a Matplotlib figure).
        dpi: Resolution forwarded to ``savefig``.

    Returns:
        The ``PIL.Image`` opened from the in-memory buffer.
    """
    buffer = io.BytesIO()
    # No explicit format is passed, so savefig uses its configured default.
    fig.savefig(buffer, dpi=dpi)
    # Rewind so PIL reads the rendered bytes from the beginning.
    buffer.seek(0)
    return Image.open(buffer)
# Function to load a band file
def load_band_file(path):
    """Load a band array from the NumPy file at ``path``.

    Args:
        path: Location of the file, passed straight to ``np.load``.

    Returns:
        Whatever ``np.load`` returns for that file.
    """
    band_array = np.load(path)
    return band_array
# Function to load a mask file
def load_mask_file(path):
    """Load a mask array from the NumPy file at ``path``.

    Args:
        path: Location of the file, passed straight to ``np.load``.

    Returns:
        Whatever ``np.load`` returns for that file.
    """
    mask_array = np.load(path)
    return mask_array
# Sort key: extracts the band number from a band file path so that the paths received from the gradio object can be ordered by band
def extract_band_number(path):
    """Return the band number encoded in a path such as ``band_XX.npy``.

    The number is taken from the text after the last underscore, up to the
    first dot that follows it.

    Args:
        path: File path string.

    Returns:
        The band number as an ``int``, or ``float('inf')`` when that text is
        not an integer, so such paths sort after every numbered band.
    """
    after_last_underscore = path.rsplit("_", 1)[-1]
    number_text = after_last_underscore.partition(".")[0]
    try:
        band_number = int(number_text)
    except ValueError:
        # Not a "band_XX" style name: report it and push it to the end.
        print(f"Error processing file path: {path}")
        return float('inf')
    return band_number
def dice_coef_pred(y_true, y_pred):
    """Smoothed Dice coefficient between two masks.

    Computes ``(2 * |A ∩ B| + 1) / (|A| + |B| + 1)`` with the Keras backend,
    where the sizes are sums over the flattened inputs.

    Args:
        y_true: Ground-truth mask values.
        y_pred: Predicted mask values, same number of elements as ``y_true``.

    Returns:
        The result of the backend division (a scalar backend value).
    """
    truth_flat = K.flatten(K.constant(y_true))
    pred_flat = K.flatten(K.constant(y_pred))
    overlap = K.sum(truth_flat * pred_flat)
    # The +1.0 terms keep the ratio defined when both masks are empty.
    numerator = 2.0 * overlap + 1.0
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + 1.0
    return numerator / denominator
def jacard_coef_pred(y_true, y_pred):
    """Smoothed Jaccard index (IoU) between two masks.

    Computes ``(|A ∩ B| + 1) / (|A| + |B| - |A ∩ B| + 1)`` with the Keras
    backend, where the sizes are sums over the flattened inputs.

    Args:
        y_true: Ground-truth mask values.
        y_pred: Predicted mask values, same number of elements as ``y_true``.

    Returns:
        The result of the backend division (a scalar backend value).
    """
    truth_flat = K.flatten(K.constant(y_true))
    pred_flat = K.flatten(K.constant(y_pred))
    overlap = K.sum(truth_flat * pred_flat)
    union = K.sum(truth_flat) + K.sum(pred_flat) - overlap
    # The +1.0 terms keep the ratio defined when both masks are empty.
    return (overlap + 1.0) / (union + 1.0)
# Loaded once at import time and used as the default ``model`` argument of
# visualize_prediction_avg. compile=False is passed because the model is only
# used for prediction here.
best_model = load_model(
    '3ch_gen1_03_-0.545_-0.503_model.h5',
    compile=False,
)
def visualize_band_images(band_paths, figsize=(10, 10), cmap='Blues'):
    """Render up to nine band files as a 3x3 grid of heatmaps.

    Paths are ordered with ``extract_band_number``; for each of the first nine
    the array is loaded and index 4 of its last axis is drawn.

    Args:
        band_paths: Iterable of ``.npy`` band file paths.
        figsize: Size of the Matplotlib figure in inches.
        cmap: Colormap passed to ``sns.heatmap``.

    Returns:
        The rendered grid as a PIL image (via ``fig2img``).
    """
    # Compute each sort key once; sorting on the key alone keeps the same
    # stable ordering as ``sorted(band_paths, key=extract_band_number)``.
    keyed_paths = sorted(
        ((extract_band_number(path), path) for path in band_paths),
        key=lambda pair: pair[0],
    )

    fig = plt.figure(figsize=figsize)
    try:
        # Slicing (rather than a fixed range(9)) avoids an IndexError when
        # fewer than nine band files are supplied.
        for i, (band_number, band_path) in enumerate(keyed_paths[:9]):
            band_data = np.load(band_path)
            # Index 4 of the last axis is the slice chosen for display.
            band_slice = band_data[:, :, 4]

            # Draw on explicit axes of this figure, not pyplot's current one.
            ax = fig.add_subplot(3, 3, i + 1)
            sns.heatmap(band_slice, cmap=cmap, cbar=False, ax=ax)

            # Label with the number parsed from the file name; fall back to
            # the positional label (bands 8..16) when it could not be parsed.
            label = band_number if isinstance(band_number, int) else i + 8
            ax.set_title(f'Band {label}')
            ax.axis('off')

        fig.tight_layout()
        return fig2img(fig)
    finally:
        # Always release the figure, including when loading or plotting fails.
        plt.close(fig)
def visualize_masks(predicted_mask_avg_binary, true_mask):
    """Draw a colour-coded comparison of a predicted mask and the true mask.

    Each pixel is assigned one of four categories:
        0 (white)  predicted 0, true 0
        1 (green)  predicted 1, true 1  -> legend "correct"
        2 (red)    predicted 1, true 0  -> legend "false"
        3 (blue)   predicted 0, true 1  -> legend "unidentified"

    Args:
        predicted_mask_avg_binary: 2-D array of 0/1 predictions.
        true_mask: 2-D array of 0/1 ground truth with the same shape.

    Returns:
        The rendered comparison as a PIL image (via ``fig2img``).
    """
    predicted = np.asarray(predicted_mask_avg_binary)
    actual = np.asarray(true_mask)

    # Explicit uint8 so the category codes 0..3 are stored as-is whatever the
    # dtype of the incoming mask.
    combined_colors = np.zeros(predicted.shape, dtype=np.uint8)
    combined_colors[(predicted == 1) & (actual == 1)] = 1
    combined_colors[(predicted == 1) & (actual == 0)] = 2
    combined_colors[(predicted == 0) & (actual == 1)] = 3

    # One colour per category code, in code order.
    cmap_colors = ['white', 'green', 'red', '#27c1f5cc']
    cmap = ListedColormap(cmap_colors)

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        # vmin/vmax pin code k to colour k. Without them imshow rescales to
        # the data's own min/max, so a mask holding only codes {0, 1} would
        # draw "correct" pixels in the "unidentified" colour.
        ax.imshow(
            combined_colors,
            cmap=cmap,
            vmin=0,
            vmax=len(cmap_colors) - 1,
            interpolation='nearest',  # categorical data: never blend codes
        )
        legend_elements = [
            mpatches.Patch(facecolor='green', edgecolor='black', label='correct'),
            mpatches.Patch(facecolor='red', edgecolor='black', label='false'),
            mpatches.Patch(facecolor='#27c1f5cc', edgecolor='black', label='unidentified'),
        ]
        # Place the legend centred just below the axes.
        ax.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, -0.05))
        ax.set_title('True vs Predicted Mask', fontsize=20, y=1.05)
        return fig2img(fig)
    finally:
        # Always release the figure, including when rendering fails.
        plt.close(fig)
def preprocess_band_files_simple(band_paths):
    """Build the normalised 3-channel model input from band files 08, 12, 16.

    Paths are ordered with ``extract_band_number``, filtered to those ending in
    ``08.npy``, ``12.npy`` or ``16.npy``, loaded, reduced to index 4 of the
    last axis, stacked on a new last axis and standardised with the global
    mean and standard deviation.

    Args:
        band_paths: Sequence of band file path strings (may be empty or None).

    Returns:
        The standardised array, or ``None`` when there is nothing to load or
        the selected files cannot be loaded and stacked.
    """
    # Nothing supplied: bail out before sorting (also tolerates None).
    if not band_paths:
        return None

    ordered_paths = sorted(band_paths, key=extract_band_number)
    selected_paths = [
        path for path in ordered_paths
        if path.endswith(("08.npy", "12.npy", "16.npy"))
    ]
    if not selected_paths:
        print("Error: No band files found in the subfolder")
        return None

    try:
        band_images = np.stack(
            [load_band_file(path)[..., 4] for path in selected_paths],
            axis=-1,
        )
    except ValueError as e:
        # e.g. arrays of different shapes; report the real cause rather than
        # claiming that no files were found.
        print(f"Error: could not load/stack band files: {e}")
        return None

    centered = band_images - np.mean(band_images)
    spread = np.std(band_images)
    if spread == 0:
        # Constant input: dividing would yield NaN everywhere.
        return centered
    return centered / spread
def predict_contrails_avg(model, band_paths):
    """Predict a contrail mask with flip-based test-time augmentation.

    The preprocessed bands are predicted four times: as-is, flipped
    left-right, flipped up-down, and flipped both ways. Each flipped
    prediction is flipped back before the four are averaged.

    Args:
        model: Object with a ``predict(batch)`` method; the result is indexed
            as ``[0, ..., 0]`` (assumes a trailing single-channel output —
            confirm against the model).
        band_paths: Band file paths forwarded to
            ``preprocess_band_files_simple``.

    Returns:
        Tuple ``(p0, prediction_avg)``: the un-augmented prediction and the
        average over the four augmented predictions.

    Raises:
        ValueError: If preprocessing produced no input array.
    """
    band_images = preprocess_band_files_simple(band_paths)
    if band_images is None:
        # Fail here with a clear message instead of inside model.predict.
        raise ValueError(
            "predict_contrails_avg: no usable band files (08, 12, 16) in band_paths"
        )

    batch = np.expand_dims(band_images, axis=0)  # Add a batch dimension

    # (axes to flip on the batched input, axes to un-flip on the 2-D output).
    # Batch axes are offset by one because of the leading batch dimension.
    flip_specs = (
        (None, None),        # original orientation
        (2, 1),              # left-right
        (1, 0),              # up-down
        ((1, 2), (0, 1)),    # left-right and up-down
    )

    predictions = []
    for batch_axes, mask_axes in flip_specs:
        if batch_axes is None:
            predictions.append(model.predict(batch)[0, ..., 0])
            continue
        flipped_pred = model.predict(np.flip(batch, axis=batch_axes))[0, ..., 0]
        predictions.append(np.flip(flipped_pred, axis=mask_axes))

    p0, p1, p2, p3 = predictions
    prediction_avg = (p0 + p1 + p2 + p3) / 4.0
    return p0, prediction_avg
def create_overlay_gif(predicted_mask_avg_binary, band_paths):
    """Write a GIF that fades the predicted mask in over the band-08 image.

    The first path containing ``band_08.npy`` is loaded, index 4 of its last
    axis is rescaled to 0..255, and 30 frames are produced with the mask
    added at opacities from 0.0 to 0.3.

    Args:
        predicted_mask_avg_binary: 0/1 mask with the same shape as the band
            slice.
        band_paths: Iterable of band file path strings.

    Returns:
        The GIF path (``"overlay.gif"``), or ``None`` when no band-08 file is
        present in ``band_paths``.
    """
    # Get the path for band_08.npy
    band_08_path = next((path for path in band_paths if 'band_08.npy' in path), None)
    if not band_08_path:
        print("band_08.npy not found in band_paths")
        return None

    image = np.load(band_08_path)[..., 4]  # Use the 5th channel (index 4)
    min_value = np.min(image)
    max_value = np.max(image)
    value_range = max_value - min_value
    if value_range == 0:
        # Constant image: avoid dividing by zero, render it as black.
        image = np.zeros(image.shape, dtype=np.float64)
    else:
        image = (image - min_value) * 255 / value_range

    mask_highlight = np.asarray(predicted_mask_avg_binary, dtype=np.float64) * 255

    frames = []
    for alpha in np.linspace(0.0, 0.3, num=30):
        blended = image + alpha * mask_highlight
        # Clip before the cast: bright pixels plus the highlight can exceed
        # 255 and would otherwise wrap around to dark values as uint8.
        frame = np.clip(blended, 0, 255).astype('uint8')
        frames.append(Image.fromarray(frame))

    # NOTE: fixed output name in the working directory; concurrent calls
    # overwrite each other's file.
    gif_path = "overlay.gif"
    frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=300, loop=0)
    return gif_path
def visualize_prediction_avg(mask_path, band_paths, model=best_model, figsize=(10, 10), cmap='Spectral', threshold=0.5):
    """Run prediction for one sample and build all demo outputs.

    Args:
        mask_path: Path to the ground-truth mask ``.npy`` file; channel 0 of
            its last axis is used.
        band_paths: Band file paths for the sample.
        model: Model passed to ``predict_contrails_avg``.
        figsize: Accepted but not used in this function.
        cmap: Accepted but not used in this function.
        threshold: Probabilities strictly above this value become 1.

    Returns:
        Tuple ``(band_vis, mask_vis, gif_path, text_output)``: the band grid
        image, the mask comparison image, the overlay GIF path and a text
        summary of the averaged Dice and IoU scores.
    """
    ground_truth = np.load(mask_path)[:, :, 0]

    # Raw prediction and flip-averaged prediction, then binarise both.
    raw_prediction, averaged_prediction = predict_contrails_avg(model, band_paths)
    raw_binary = (raw_prediction > threshold).astype(np.uint8)
    averaged_binary = (averaged_prediction > threshold).astype(np.uint8)

    band_vis = visualize_band_images(band_paths)
    mask_vis = visualize_masks(averaged_binary, ground_truth)

    # Scores for the raw and the averaged prediction.
    dice_original = dice_coef_pred(ground_truth, raw_binary)
    dice_avg = dice_coef_pred(ground_truth, averaged_binary)
    jacard_original = jacard_coef_pred(ground_truth, raw_binary)
    jacard_avg = jacard_coef_pred(ground_truth, averaged_binary)
    print(f"Dice Coefficient - Original: {dice_original}, Averaged: {dice_avg}")
    print(f"IoU - Original: {jacard_original}, Averaged: {jacard_avg}")

    # Only the averaged scores go into the returned text.
    text_output = (
        f"Dice Coefficient - Averaged: {dice_avg:.3f}\n"
        f"IoU - Averaged: {jacard_avg:.3f}"
    )

    gif_path = create_overlay_gif(averaged_binary, band_paths)
    return band_vis, mask_vis, gif_path, text_output