# AlexNet_CNN_Visualisation / visualise2.py
# NOTE(review): the lines below were Hugging Face web-page chrome accidentally
# captured when the raw file was scraped; they are not valid Python and are
# commented out so the module stays importable.
#   uploader: makiisthebes — "Upload 18 files"
#   commit: 4ec6f12 (verified) · raw / history / blame · 5.76 kB · "No virus"
# https://tree.rocks/get-heatmap-from-cnn-convolution-neural-network-aka-grad-cam-222e08f57a34
import cv2, os, torch, re
import matplotlib.pyplot as plt
from scipy.ndimage import zoom
import numpy as np
from model import MakiAlexNet
from tqdm import tqdm
# from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
TOP_ACCURACY_PERCENTILE = 10
TEST_IMAGE = "dataset/root/train/left1_frame_10.jpg"
MODEL_PARAMS = "alexnet_cognitive.pth"
GIF_STORE = "dataset/gifs/"
TRAIN_STORE = "dataset/root/train/"
# ---------------------------------------------------------------------------
# Model setup: load the trained AlexNet weights and switch to inference mode.
# ---------------------------------------------------------------------------
model = MakiAlexNet()
# map_location makes the checkpoint load on CPU-only machines even when the
# weights were saved from a CUDA device (torch.load would otherwise raise).
_device = "cuda" if torch.cuda.is_available() else "cpu"
model.load_state_dict(torch.load(MODEL_PARAMS, map_location=_device))
model.eval()

# Run on the GPU when one is available.
if torch.cuda.is_available():
    model = model.cuda()
    print("Running on cuda")

print(dir(model))
for name, module in model.named_modules():
    # Debug aid: list every registered layer name (useful for picking the
    # key looked up later in model.layer_outputs).
    print(name)
def extract_file_paths(filename):
    """Parse a training-frame filename into its component fields.

    Expects names shaped like ``left1_frame_10.jpg`` (regex built with help
    from https://regex101.com/) and returns the tuple
    ``(frame_no, frame_name, video_no)`` — e.g. ``("10", "left", "1")``.
    """
    match = re.search(r"(left|right)([0-9]+)(_frame_)([0-9]+)", filename)
    frame_name, video_no, _, frame_no = match.groups()
    return frame_no, frame_name, video_no
def create_mp4_from_frames(file_name, frames):
    """Write the given image frames to ``dataset/gifs/<file_name>.mp4``.

    Parameters
    ----------
    file_name : str
        Output file stem; ``.mp4`` is appended.
    frames : list[str]
        Paths to frame images, written in sorted (chronological) order.
        All frames are assumed to share the first frame's dimensions.
    """
    ordered = sorted(frames)
    print("Sorted frames: ", ordered)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Measure dimensions on the first frame actually written (previously the
    # unsorted frames[0] was measured while sorted order was written).
    height, width, _ = cv2.imread(ordered[0]).shape
    fps = 20  # Adjust the frames per second (FPS) as needed
    video_path = os.path.join(os.getcwd(), "dataset", "gifs", f"{file_name}.mp4")
    video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    try:
        for frame_path in ordered:
            # BUG FIX: cv2.VideoWriter expects BGR input — the order
            # cv2.imread already produces. The previous BGR->RGB conversion
            # swapped the red/blue channels in the rendered video.
            video.write(cv2.imread(frame_path))
    finally:
        # Always release the writer so the MP4 container is finalized,
        # even if a frame fails to load mid-way.
        video.release()
# ---------------------------------------------------------------------------
# Main loop: run each training frame through the network, overlay the Conv2d
# activation maps on the input image, and (eventually) collect the rendered
# heat-map frames per video sequence for MP4 generation.
# ---------------------------------------------------------------------------
current_video_name = None
selected_frames = []  # saved heat-map frame paths for the video generation.

# NOTE(review): the full sweep over TRAIN_STORE is commented out; only one
# hard-coded frame is processed and the trailing exit() stops after it.
for image_filename in ["left1_frame_5.jpg"]:  # tqdm(sorted(os.listdir(TRAIN_STORE)), desc="Running Images"):
    frame_no, frame_name, video_no = extract_file_paths(image_filename)
    obtained_video_name = video_no + "vid" + frame_name

    if current_video_name != obtained_video_name:
        # A new video sequence begins: flush the frames collected so far.
        if selected_frames:
            filename = f"{current_video_name}"
            # Guard against the very first iteration (no previous video yet).
            if current_video_name:
                create_mp4_from_frames(filename, selected_frames)
        # Reset state and hand off to the new sequence handle.
        selected_frames = []
        current_video_name = obtained_video_name

    # Load the frame and build a float32 tensor with a leading batch
    # dimension, then reorder channel-last -> channel-first for the model.
    img = cv2.imread(os.path.join(TRAIN_STORE, image_filename))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = torch.unsqueeze(torch.tensor(img.astype(np.float32)), 0)
    X = torch.einsum("BWHC->BCWH", img)
    if torch.cuda.is_available():
        X = X.cuda()

    output = model(X)

    # Fetch the activations cached by the model for its Conv2d layer and
    # move them to channel-last numpy form. Observed shape: (1, 256, 12, 12)
    # before the transpose.
    conv = model.layer_outputs['Conv2d']
    conv = torch.einsum("BCWH->BWHC", conv).cpu().detach().numpy()

    # Upscale factor from the conv feature-map resolution to the 224-px
    # input size — derived from the activation shape instead of the previous
    # hard-coded 224 / 12, so a different tap layer still scales correctly.
    scale = 224 / conv.shape[1]

    # BUG FIX: matplotlib clips float images to [0, 1]; convert the 0-255
    # float tensor back to uint8 so the base photo renders correctly.
    base_image = img.squeeze(0).numpy().astype(np.uint8)

    plt.figure(figsize=(16, 16))
    total_mat = None
    for i in range(256):
        plt.subplot(16, 16, i + 1)
        plt.imshow(base_image)
        plt.imshow(zoom(conv[0, :, :, i], zoom=(scale, scale)), cmap='jet', alpha=0.3)
    # BUG FIX: show once after all 256 subplots are drawn — previously
    # plt.show() sat inside the loop, popping a window per filter (the
    # reference snippet at the bottom of the file shows it after the loop).
    plt.show()

    exit()
# plt.figure(figsize=(16, 16))
# for i in range(36):
# plt.subplot(6, 6, i + 1)
# plt.imshow(cv2.imread(TEST_IMAGE))
# plt.imshow(zoom(conv[0, :,:,i], zoom=(scale, scale)), cmap='jet', alpha=0.3)
#
# plt.show()