AlexNet_CNN_Visualisation /
makiisthebes's picture
Upload 18 files
4ec6f12 verified
history blame contribute delete
No virus
5.76 kB
import cv2, os, torch, re
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
import numpy as np
from model import MakiAlexNet
from tqdm import tqdm
# from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
# Paths used by this script (relative to the working directory).
TEST_IMAGE = "dataset/root/train/left1_frame_10.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"
GIF_STORE = "dataset/gifs/"
TRAIN_STORE = "dataset/root/train/"

model = MakiAlexNet()
# NOTE(review): MODEL_PARAMS is never loaded in this file. Unless MakiAlexNet()
# loads it itself, the visualised activations come from untrained weights — confirm.
# This script only runs inference, so put dropout/batch-norm layers (if any)
# into evaluation mode for deterministic activations.
model.eval()

# Make model run on cuda if available.
if torch.cuda.is_available():
    model = model.cuda()
    print("Running on cuda")

# Print the name of every sub-module, to help pick the layer to visualise.
for name, _ in model.named_modules():
    print(name)
def extract_file_paths(filename):
    """Split a frame filename such as ``left1_frame_5.jpg`` into its parts.

    The pattern was built with the aid of an online regex tester.

    Args:
        filename: File name (or path) containing
            ``(left|right)<digits>_frame_<digits>``.

    Returns:
        Tuple ``(frame_no, frame_name, video_no)`` of strings, e.g.
        ``("5", "left", "1")`` for ``"left1_frame_5.jpg"``.

    Raises:
        ValueError: If ``filename`` does not contain the expected pattern.
    """
    extractor_reg = r"(left|right)([0-9]+)(_frame_)([0-9]+)"
    result = re.search(extractor_reg, filename)
    if result is None:
        # Without this check the .group() calls below fail with an opaque
        # AttributeError on None.
        raise ValueError(f"Unrecognised frame filename: {filename!r}")
    frame_no = result.group(4)    # digits following "_frame_"
    frame_name = result.group(1)  # "left" or "right"
    video_no = result.group(2)    # digits following "left"/"right"
    return frame_no, frame_name, video_no
def create_mp4_from_frames(file_name, frames, fps=20):
    """Write the given image files, in sorted path order, to ``dataset/gifs/<file_name>.mp4``.

    The resulting clip lasts ``len(frames) / fps`` seconds.

    Args:
        file_name: Output file stem (without the ``.mp4`` extension).
        frames: Paths of image files readable by ``cv2.imread``.
        fps: Frames per second of the output video (default 20).

    Returns:
        Path of the written video file.

    Raises:
        ValueError: If ``frames`` is empty.
        FileNotFoundError: If a frame image cannot be read.
        OSError: If the video writer cannot be opened.
    """
    if not frames:
        raise ValueError("create_mp4_from_frames() needs at least one frame")
    ordered = sorted(frames)  # lexicographic path order
    print("Sorted frames: ", ordered)

    first = cv2.imread(ordered[0])
    if first is None:
        raise FileNotFoundError(f"Could not read frame: {ordered[0]}")
    height, width = first.shape[:2]

    video_dir = os.path.join(os.getcwd(), "dataset", "gifs")
    # VideoWriter does not create missing directories.
    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, f"{file_name}.mp4")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    if not video.isOpened():
        raise OSError(f"Could not open video writer for {video_path}")
    try:
        for frame_path in ordered:
            image = cv2.imread(frame_path)
            if image is None:
                raise FileNotFoundError(f"Could not read frame: {frame_path}")
            # Every frame must match the size the writer was opened with.
            if image.shape[:2] != (height, width):
                image = cv2.resize(image, (width, height))
            # cv2.imread returns BGR, which is what VideoWriter expects;
            # converting to RGB here would swap red and blue in the video.
            video.write(image)
    finally:
        # Release the VideoWriter so the file is finalised and closed.
        video.release()
    return video_path
# Group consecutive frames by video key (video number + "vid" + left/right) and
# turn each group of rendered frames into an MP4.
current_video_name = None
selected_frames = []  # stores paths of rendered frames for the current video sequence.

# Only one frame is processed for now; swap in the commented iterator to run the
# whole training set.
for image_filename in ["left1_frame_5.jpg"]:  # tqdm(sorted(os.listdir(TRAIN_STORE)), desc="Running Images"):
    frame_no, frame_name, video_no = extract_file_paths(image_filename)
    obtained_video_name = video_no + "vid" + frame_name
    if current_video_name != obtained_video_name:
        # We have a new video sequence, so save the current sequence and switch name.
        if selected_frames and current_video_name:
            create_mp4_from_frames(f"{current_video_name}", selected_frames)
        # Clear frames and hand off to the new sequence.
        selected_frames = []
        current_video_name = obtained_video_name

    image_path = os.path.join(TRAIN_STORE, image_filename)
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # (H, W, C) image -> float32 tensor with a batch dimension (1, H, W, C),
    # then permuted to channels-first (1, C, H, W) as Conv2d layers expect.
    img = torch.unsqueeze(torch.tensor(img.astype(np.float32)), 0)
    X = torch.einsum("BWHC->BCWH", img)
    if torch.cuda.is_available():
        X = X.cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        model(X)  # activations are read from model.layer_outputs below

    # Per the original note this is (1, 256, 12, 12); converted to channels-last
    # numpy (1, H, W, C).
    conv = model.layer_outputs['Conv2d']
    conv = torch.einsum("BCWH->BWHC", conv).cpu().detach().numpy()
    # Upsample each map to ~224x224, derived from the actual map size rather
    # than a hard-coded 12.
    scale = (224 / conv.shape[1], 224 / conv.shape[2])
    # The 16x16 subplot grid holds at most 256 maps; fewer channels must not
    # index past the end of the array.
    n_maps = min(256, conv.shape[-1])
    # NOTE(review): this figure is never shown, saved or closed in this file.
    plt.figure(figsize=(16, 16))
    total_mat = None
    for i in range(n_maps):
        plt.subplot(16, 16, i + 1)
        plt.imshow(zoom(conv[0, :, :, i], zoom=scale), cmap='jet', alpha=0.3)
        # Work in progress: threshold each map and accumulate a combined heat map.
        # mat = zoom(conv[0, :, :, i], zoom=(scale, scale))
        # threshold = np.percentile(mat.flatten(), TOP_ACCURACY_PERCENTILE)
        # # The lower the threshold is to zero, the more specific the look is shown.
        # mask = mat > threshold
        # # OR: filter_map = np.where(filter_map <= threshold, 0, filter_map)
        # # Rescale remaining values (adjust new_range if needed)
        # new_range = 1  # Adjust based on your desired final range
        # filter_map = np.where(mask, (mat - threshold) / (mat.max() - threshold) * new_range, 0)
        # # I just add all the maps together, which is really noisy.
        # if total_mat is not None:
        #     total_mat += filter_map
        # else:
        #     total_mat = filter_map
    # Work in progress: normalise the combined map, overlay it on the frame, save
    # it and collect it for the video of this sequence.
    # # Normalize based on largest value,
    # # Store this image in a collection, in which a GIF will be made, that lasts at least 2 seconds.
    # total_mat = total_mat / abs(np.max(total_mat))
    # image = img.squeeze(0)  # .detach().numpy().astype(np.float32)
    # plt.imshow(plt.imread(os.path.join(os.getcwd(), "dataset/root/train", image_filename)))  # full path needed
    # plt.imshow(total_mat, cmap='jet', alpha=0.3)
    # filename = frame_name+frame_no+video_no+".jpg"
    # file_path = os.path.join(os.getcwd(), "dataset/gifs/raw/", filename)
    # plt.savefig(file_path)
    # selected_frames.append(file_path)

# Flush the final sequence: inside the loop a video is written only when the
# video name changes, so the last group would otherwise never be written.
if selected_frames and current_video_name:
    create_mp4_from_frames(f"{current_video_name}", selected_frames)

# plt.figure(figsize=(16, 16))
# for i in range(36):
#     plt.subplot(6, 6, i + 1)
#     plt.imshow(cv2.imread(TEST_IMAGE))
#     plt.imshow(zoom(conv[0, :,:,i], zoom=(scale, scale)), cmap='jet', alpha=0.3)