# Based on the learnt CNN kernels, this script aids in visualising the learnt
# kernel patterns: it runs training frames through the first convolutional
# layer and merges the per-kernel activations into masks and heatmaps.

# Attempt 1: did not work well.

# The aim is to determine how strongly each part of the image contributes to
# detecting the goal, and how that weighting changes over time.

# Reference on visualising CNN heatmaps:
# https://www.youtube.com/watch?v=ST9NjnKKvT8

import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from dataset_creation import normal_transforms
from model import MakiAlexNet

TEST_IMAGE = "dataset/root/train/left1_frame_0.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"
TRAIN_DIR = os.path.join(os.getcwd(), "dataset/root/train")
DATASET_OUTPUT_PATH = "./dataset/visualisation"

# Only process image files; a stray non-image in the directory would make
# cv2.imread return None and crash the loop below.
all_processing_files = [f for f in os.listdir(TRAIN_DIR) if f.endswith(".jpg")]

model = MakiAlexNet()

model.load_state_dict(torch.load(MODEL_PARAMS))
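# Note: if the checkpoint was saved from a GPU run, loading on a CPU-only
# machine may need map_location (an assumption about how alexnet_cognitive.pth
# was produced):
# model.load_state_dict(torch.load(MODEL_PARAMS, map_location="cpu"))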
model.eval()
print("Model armed and ready for evaluation.")

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
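
# A minimal sketch of an alternative way to capture activations: a forward hook
# fires during a normal forward pass, so it generalises to deeper layers without
# calling each layer by hand. This assumes MakiAlexNet exposes its first
# convolution as the attribute `conv1`, as the loop below does.
activations = {}

def save_activation(name):
  """Return a hook that stores the named layer's output on each forward pass."""
  def hook(module, inputs, output):
    activations[name] = output.detach()
  return hook

# Uncomment to capture conv1 via the hook instead of calling model.conv1 directly:
# model.conv1.register_forward_hook(save_activation("conv1"))
# _ = model(image)
# conv1_output = activations["conv1"]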




for image_file in tqdm(all_processing_files):

  # Load the image from file.
  abs_file_path = os.path.join(TRAIN_DIR, image_file)
  image = cv2.imread(abs_file_path)
  # print(image.shape)
  # cv2.imshow("test", image)
  # cv2.waitKey(5000)

  print("Image input shape of the matrix before: ", image.shape)
  # OpenCV loads images as (H, W, C); add a batch dimension to get (B, H, W, C),
  # then permute to the (B, C, H, W) layout torch convolutions expect.
  image = torch.unsqueeze(torch.tensor(image.astype(np.float32)), 0)
  image = torch.einsum("BHWC->BCHW", image)
  print("Image input shape of the matrix after: ", image.shape)
  conv1_output = model.conv1(image)
  print("Output shape of the matrix: ", conv1_output.shape)
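  # For reference, conv1_output is laid out as (B, C, H, W); with this project's
  # frames the spatial size comes out around 54x54 (hence the upsampling below),
  # though the exact shape depends on MakiAlexNet's kernel size, stride, and padding.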


  # Rearrange the convolution output for per-channel handling: sum out the
  # batch axis (size 1, so a no-op) and move channels last to give (H, W, C).
  conv1_formatted = torch.einsum("BCHW->HWC", conv1_output)
  print(f"Formatted shape of matrix is: {conv1_formatted.shape}")


  # Lay the channels out on a grid: for conv1's 96 kernels this gives
  # rows = min(5, int(sqrt(96))) = 5 and cols = ceil(96 / 5) = 20.
  num_channels = conv1_formatted.shape[2]
  max_rows = 5  # Cap on the number of rows
  rows = min(max_rows, int(np.sqrt(num_channels)))
  cols = int(np.ceil(num_channels / rows))
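
  # A sketch (left commented) of rendering each kernel's activation map into
  # that rows x cols grid as a single figure, which is what the layout above sizes:
  # fig, axes = plt.subplots(rows, cols, figsize=(12, 12))
  # for idx in range(num_channels):
  #   ax = axes[idx // cols][idx % cols]
  #   ax.imshow(conv1_formatted[:, :, idx].detach().numpy())
  #   ax.axis('off')
  # fig.savefig("conv1_grid.jpg")  # or plt.show()
  # plt.close(fig)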

  merged_frames = np.zeros((224, 224))
  image_file_dir = os.path.splitext(image_file)[0]  # e.g. "left1_frame_0"
  output_dir = os.path.join(os.getcwd(), DATASET_OUTPUT_PATH, image_file_dir)
  os.makedirs(output_dir, exist_ok=True)  # Per-image output directory.


  for i in range(rows):
    for j in range(cols):
      channel_idx = i * cols + j  # Row-major index into the channel grid
      if channel_idx < num_channels:
        channel_data = conv1_formatted[:, :, channel_idx].detach().numpy()
        print(f"Channel data shape: {channel_data.shape}")
        # Upsample the activation map (roughly 54x54 for conv1) to the input size.
        channel_data = cv2.resize(channel_data, (224, 224))

        # Accumulate the channel activations into a single merged frame.
        # Optionally threshold each channel first, so only strong activations
        # contribute:
        # ret, channel_data = cv2.threshold(channel_data, 120, 255, cv2.THRESH_BINARY)
        merged_frames += channel_data




        # # Save the per-channel image data matrix.
        # image_filename = f"{int(time.time())}_output_{channel_idx}.jpg"
        # image_path = os.path.join(output_dir, image_filename)
        # plt.imsave(image_path, channel_data)
        # print(f"Image path saved at {image_path}")




  # Normalise the merged frame into the 0-255 range. Dividing by 0.8 of the max
  # boosts the brightest regions; anything that overflows is clipped to 255.
  # (Casting the raw float ratio straight to uint8 would truncate nearly
  # everything to 0 or 1.)
  merged_frames = merged_frames / (np.max(merged_frames) * 0.8) * 255
  merged_frames_gray = np.clip(merged_frames, 0, 255).astype(np.uint8)

  # Optional: adaptive thresholding to isolate the strongest highlights.
  # merged_frames_gray = cv2.adaptiveThreshold(merged_frames_gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
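  # An equivalent sketch using OpenCV's built-in normalisation (NORM_MINMAX
  # stretches linearly to 0-255, without the 0.8 boost used above):
  # merged_frames_gray = cv2.normalize(merged_frames, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)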


  # Save the merged activation mask as a grayscale image.
  image_path = os.path.join(output_dir, image_file_dir + "_conv1_mask.jpg")
  plt.imsave(image_path, merged_frames_gray, cmap='gray')

  # Apply a colormap to produce a heatmap. OpenCV works in BGR, so write with
  # cv2.imwrite; plt.imsave expects RGB and would swap the red/blue channels.
  heatmap_color = cv2.applyColorMap(merged_frames_gray, cv2.COLORMAP_JET)
  # cv2.imshow("merged", heatmap_color)
  # cv2.waitKey(5000)
  image_path = os.path.join(output_dir, image_file_dir + "_conv1_heatmap.jpg")
  cv2.imwrite(image_path, heatmap_color)

exit()

# Scratch from an earlier attempt, using the dataset's normalisation transform:
# image_tensor = normal_transforms(torch.tensor(image))
# print(image_tensor.shape)
# plt.imshow(image_tensor.squeeze())