# AlexNet_CNN_Visualisation / model_visualisation.py
# Uploaded by makiisthebes in commit 4ec6f12 ("Upload 18 files").
# Based on the learnt CNN kernels, this script helps generate a visualisation of the learnt kernel patterns.
# Attempt 1: did not work well.
import matplotlib.pyplot as plt
# The aim is to determine how much each part of the image contributes to detecting the goal,
# and how these contributions change over time.
# https://www.youtube.com/watch?v=ST9NjnKKvT8
# This video tackles the same problem by walking through heatmaps of CNNs.
from torchvision import transforms
from dataset_creation import normal_transforms
from model import MakiAlexNet
import numpy as np
import cv2, torch, os
from tqdm import tqdm
import time
TEST_IMAGE = "dataset/root/train/left1_frame_0.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"

# Every entry of the training folder is processed. os.listdir returns entries
# in arbitrary order, so sort them to make runs reproducible.
all_processing_files = sorted(
    os.listdir(os.path.join(os.getcwd(), "./dataset/root/train"))
)

model = MakiAlexNet()
# The model is never moved to a GPU in this script, so map the checkpoint
# onto the CPU as well; this also lets a checkpoint saved from a CUDA run load
# on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))
model.eval()  # inference mode for layers such as dropout / batch-norm
print("Model armed and ready for evaluation.")

# Print model's state_dict: one "name <tab> size" line per tensor.
print("Model's state_dict:")
for param_tensor, tensor_value in model.state_dict().items():
    print(param_tensor, "\t", tensor_value.size())
DATASET_OUTPUT_PATH = "./dataset/visualisation"
MAP_SIZE = (224, 224)  # (width, height) passed to cv2.resize
# Fraction of the peak activation that already maps to full intensity (255);
# values in [SATURATION * peak, peak] saturate, brightening the strongest areas.
SATURATION = 0.8


def image_to_batch(image):
    """Return an OpenCV ``(H, W, C)`` image as a ``(1, C, H, W)`` float32 tensor.

    Pixel values are kept unchanged (no scaling or mean/std normalisation).
    NOTE(review): confirm this raw range matches what the model was trained on.
    """
    tensor = torch.from_numpy(image.astype(np.float32))
    return tensor.permute(2, 0, 1).unsqueeze(0)


def activation_to_uint8(activation, saturation=SATURATION):
    """Scale a 2-D activation map into a displayable uint8 image.

    The map is multiplied by ``255 / (saturation * max)`` and clipped to
    [0, 255], so negative responses become 0 and everything at or above
    ``saturation * max`` becomes 255. A map whose maximum is not positive has
    nothing to highlight and yields all zeros (avoids dividing by zero or
    flipping the sign). ``saturation`` is expected to lie in (0, 1].
    """
    activation = np.asarray(activation, dtype=np.float64)
    peak = activation.max() if activation.size else 0.0
    if peak <= 0:
        return np.zeros(activation.shape, dtype=np.uint8)
    scaled = activation * (255.0 / (peak * saturation))
    # Clip before casting: casting out-of-range floats to uint8 is undefined.
    return np.clip(scaled, 0.0, 255.0).astype(np.uint8)


if __name__ == "__main__":
    train_dir = os.path.join(os.getcwd(), "./dataset/root/train")
    for image_file in tqdm(all_processing_files):
        # Load the image from file; cv2.imread returns None for entries it
        # cannot decode (e.g. non-image files also listed by os.listdir).
        abs_file_path = os.path.join(train_dir, image_file)
        image = cv2.imread(abs_file_path)
        if image is None:
            tqdm.write(f"Skipping unreadable file: {abs_file_path}")
            continue

        print("Image input shape of the matrix before: ", image.shape)
        batch = image_to_batch(image)
        print("Image input shape of the matrix after: ", batch.shape)

        # Visualisation only: no autograd graph is needed.
        with torch.no_grad():
            conv1_output = model.conv1(batch)
        print("Output shape of the matrix: ", conv1_output.shape)

        # Sum over batch and channel axes, then upsample once. Bilinear
        # resizing is linear, so this equals resizing every channel and
        # accumulating the results, without one resize per channel.
        summed = conv1_output.sum(dim=(0, 1)).numpy()
        merged_frames = cv2.resize(summed, MAP_SIZE)
        merged_frames_gray = activation_to_uint8(merged_frames)

        # One output folder per image, named after the file without extension.
        image_file_dir = os.path.splitext(image_file)[0]
        output_dir = os.path.join(os.getcwd(), DATASET_OUTPUT_PATH, image_file_dir)
        os.makedirs(output_dir, exist_ok=True)

        mask_path = os.path.join(output_dir, image_file_dir + "conv1_mask.jpg")
        plt.imsave(mask_path, merged_frames_gray, cmap="gray", vmin=0, vmax=255)

        # applyColorMap returns BGR while plt.imsave expects RGB.
        heatmap_bgr = cv2.applyColorMap(merged_frames_gray, cv2.COLORMAP_JET)
        heatmap_path = os.path.join(output_dir, image_file_dir + "conv1_heatmap.jpg")
        plt.imsave(heatmap_path, cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB))
#
# image_tensor = normal_transforms(torch.tensor(image))
# print(image_tensor.shape)
# plt.imshow(image_tensor.squeeze())