# Based on the learnt CNN kernels, this script will aid in generating a learnt
# kernel pattern.
# Here we should be able to determine what weighting each part of the image
# aids in the detection of the goal, and how these change over time.
# https://www.youtube.com/watch?v=ST9NjnKKvT8
# This video aims to solve this problem, by going over the heatmaps of CNNs.
import matplotlib.pyplot as plt
from torchvision import transforms
from dataset_creation import normal_transforms
from model import MakiAlexNet
import numpy as np
import cv2, torch, os
from tqdm import tqdm
import time

TEST_IMAGE = "dataset/root/train/left1_frame_0.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"
TRAIN_DIR = os.path.join(os.getcwd(), "dataset", "root", "train")
DATASET_OUTPUT_PATH = os.path.join(os.getcwd(), "dataset", "visualisation")
# Output resolution of the merged activation map (width, height).
TARGET_SIZE = (224, 224)
# Fraction of the per-image peak at which the merged map saturates to white;
# activations above peak * SATURATION_FRACTION are clipped to 255.
SATURATION_FRACTION = 0.8


def _load_model(params_path):
    """Load MakiAlexNet weights from *params_path* and switch to eval mode."""
    model = MakiAlexNet()
    model.load_state_dict(torch.load(params_path))
    model.eval()
    print("Model armed and ready for evaluation.")
    # Print model's state_dict so the layer shapes are visible in the log.
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
    return model


def _conv1_activation_map(model, image_bgr):
    """Sum all conv1 channel activations for one image into a 2-D float map.

    image_bgr: uint8 array as returned by cv2.imread (H, W, 3 assumed —
    TODO confirm dataset images are 3-channel). Returns a float64 array of
    shape TARGET_SIZE with each channel's response resized and accumulated.
    """
    # (H, W, C) -> (1, C, H, W) float32 batch for the conv layer.
    # NOTE(review): einsum labels follow the original script's "BWHC" naming;
    # cv2 actually yields (H, W, C), but the permutation moves channels
    # correctly either way.
    tensor = torch.unsqueeze(torch.tensor(image_bgr.astype(np.float32)), 0)
    tensor = torch.einsum("BWHC->BCWH", tensor)
    # Inference only — no autograd graph needed.
    with torch.no_grad():
        conv1_output = model.conv1(tensor)
    # Drop the singleton batch dim and move channels last: (W, H, C).
    activations = torch.einsum("BCWH->WHC", conv1_output).numpy()

    merged = np.zeros(TARGET_SIZE, dtype=np.float64)
    # Original code walked a rows x cols subplot grid with a bounds check;
    # that is equivalent to iterating every channel index directly.
    for channel_idx in range(activations.shape[2]):
        merged += cv2.resize(activations[:, :, channel_idx], TARGET_SIZE)
    return merged


def _to_gray_u8(merged):
    """Normalise an accumulated activation map to a uint8 [0, 255] image.

    Fixes the original bug where the map was divided by 0.8 * max (range
    ~[0, 1.25]) and cast straight to uint8, which truncated everything to
    0 or 1 and produced black output images.
    """
    peak = float(np.max(merged))
    if peak <= 0.0:
        # Degenerate map (e.g. all-zero activations): avoid division by zero.
        return np.zeros(merged.shape, dtype=np.uint8)
    scaled = merged / (peak * SATURATION_FRACTION) * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)


def main():
    """Render conv1 activation masks and heatmaps for every training image."""
    model = _load_model(MODEL_PARAMS)

    for image_file in tqdm(os.listdir(TRAIN_DIR)):
        abs_file_path = os.path.join(TRAIN_DIR, image_file)
        image = cv2.imread(abs_file_path)
        if image is None:
            # Not a readable image (stray file in the dataset dir) — skip.
            continue
        print("Image input shape of the matrix before: ", image.shape)

        merged_frames = _conv1_activation_map(model, image)
        merged_frames_gray = _to_gray_u8(merged_frames)

        # Portable stem extraction (the original split on "/" and ".jpg").
        stem = os.path.splitext(os.path.basename(abs_file_path))[0]
        out_dir = os.path.join(DATASET_OUTPUT_PATH, stem)
        # makedirs also creates the visualisation parent dir if missing.
        os.makedirs(out_dir, exist_ok=True)

        mask_path = os.path.join(out_dir, stem + "conv1_mask.jpg")
        plt.imsave(mask_path, merged_frames_gray, cmap='gray')

        heatmap_color = cv2.applyColorMap(merged_frames_gray, cv2.COLORMAP_JET)
        # applyColorMap returns BGR; convert so plt.imsave writes true colors.
        heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
        heatmap_path = os.path.join(out_dir, stem + "conv1_heatmap.jpg")
        plt.imsave(heatmap_path, heatmap_color)
        plt.close()
        # NOTE: the original script called exit() here, so only the first
        # image was ever processed; removed so the whole dataset is covered.


if __name__ == "__main__":
    main()