# LUWA / utils/util_function.py
# File-page metadata: author DanielXu0208, commit 785ef2b ("Initial commit"), 7.61 kB.
import cv2
from sklearn.manifold import TSNE
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn import decomposition
import itertools
def normalize_image(image):
    """Min-max rescale *image* so its values span roughly [0, 1) for display.

    The smallest element maps to 0 and the largest to just under 1. The 1e-5
    added to the denominator avoids division by zero when every element is
    equal (a constant input maps to all zeros).

    The input is NOT modified: a new tensor/array is returned. Callers in this
    file pass views (e.g. permuted images, slices of a filter bank), so an
    in-place rescale would silently overwrite the caller's data.

    Args:
        image: A torch tensor (or NumPy array) of any shape. Integer inputs
            produce a floating-point result.

    Returns:
        The rescaled values, with the same shape as *image*.
    """
    image_min = image.min()
    image_max = image.max()
    # Out-of-place arithmetic: identical values to an in-place rescale, but the
    # caller's storage is left untouched and integer dtypes are promoted.
    return (image - image_min) / (image_max - image_min + 1e-5)
def plot_lr_finder(fig_name, lrs, losses, skip_start=5, skip_end=5):
    """Plot loss against learning rate (log-scaled x axis), save it, then show it.

    Args:
        fig_name: Path the figure is written to.
        lrs: Sequence of learning rates, one per recorded step.
        losses: Sequence of losses aligned with ``lrs``.
        skip_start: Number of leading points to drop.
        skip_end: Number of trailing points to drop; 0 keeps everything after
            ``skip_start``.
    """
    # A slice stop of -0 equals 0 and would drop everything, hence None for 0.
    stop = None if skip_end == 0 else -skip_end
    lrs = lrs[skip_start:stop]
    losses = losses[skip_start:stop]

    fig = plt.figure(figsize=(16, 8))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(lrs, losses)
    ax.set_xscale('log')
    ax.set_xlabel('Learning rate')
    ax.set_ylabel('Loss')
    ax.grid(True, which='both', axis='x')
    # Save BEFORE showing: interactive/inline backends discard the figure once
    # it has been shown, so a savefig after show() would write a blank image.
    fig.savefig(fig_name)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def epoch_time(start_time, end_time):
    """Split the interval ``end_time - start_time`` into whole minutes and seconds.

    The interval is divided by 60, so the timestamps are presumably in seconds
    (e.g. from ``time.time()``). Both parts are truncated toward zero with
    ``int``; a negative interval therefore yields non-positive parts.

    Returns:
        Tuple ``(minutes, seconds)`` of ints.
    """
    span = end_time - start_time
    minutes = int(span / 60)
    return minutes, int(span - 60 * minutes)
def plot_confusion_matrix(fig_name, labels, pred_labels, classes, figsize=(50, 50), fontsize=50):
    """Draw the confusion matrix of true vs. predicted labels and save it.

    Args:
        fig_name: Path the figure is written to.
        labels: True labels, one per sample.
        pred_labels: Predicted labels, aligned with ``labels``.
        classes: Display names for the matrix rows/columns; presumably in the
            sorted order of the label values used by ``confusion_matrix``.
        figsize: Figure size in inches (default keeps the original 50x50).
        fontsize: Font size for cell counts, tick labels and axis labels.
    """
    fig = None
    # Scope the large font to this figure only. Updating rcParams globally
    # after plotting left the first call's cell counts at the default size and
    # leaked font size 50 into every later figure in the process.
    with plt.rc_context({'font.size': fontsize}):
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1)
        cm = confusion_matrix(labels, pred_labels)
        cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)
        cm_display.plot(values_format='d', cmap='Blues', ax=ax)
        fig.delaxes(fig.axes[1])  # delete the colorbar added by plot()
        # Style the matrix axes explicitly rather than pyplot's "current" axes.
        ax.tick_params(axis='x', labelrotation=90, labelsize=fontsize)
        ax.tick_params(axis='y', labelsize=fontsize)
        ax.set_xlabel('Predicted Label', fontsize=fontsize)
        ax.set_ylabel('True Label', fontsize=fontsize)
        fig.savefig(fig_name)
    # The figure is very large; release it instead of keeping it registered.
    plt.close(fig)
def plot_confusion_matrix_SVM(fig_name, true_labels, predicted_labels, classes, figsize=(100, 100), fontsize=50):
    """Draw the confusion matrix for SVM predictions and save it.

    Same layout as ``plot_confusion_matrix`` but with a larger default canvas.

    Args:
        fig_name: Path the figure is written to.
        true_labels: True labels, one per sample.
        predicted_labels: Predicted labels, aligned with ``true_labels``.
        classes: Display names for the matrix rows/columns; presumably in the
            sorted order of the label values used by ``confusion_matrix``.
        figsize: Figure size in inches (default keeps the original 100x100).
        fontsize: Font size for cell counts, tick labels and axis labels.
    """
    fig = None
    # Scope the large font to this figure only. A global rcParams update made
    # after plotting did not affect this call's cell counts but leaked into
    # every subsequent figure.
    with plt.rc_context({'font.size': fontsize}):
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1)
        cm = confusion_matrix(true_labels, predicted_labels)
        cm_display = ConfusionMatrixDisplay(cm, display_labels=classes)
        cm_display.plot(values_format='d', cmap='Blues', ax=ax)
        fig.delaxes(fig.axes[1])  # delete the colorbar added by plot()
        ax.tick_params(axis='x', labelrotation=90, labelsize=fontsize)
        ax.tick_params(axis='y', labelsize=fontsize)
        ax.set_xlabel('Predicted Label', fontsize=fontsize)
        ax.set_ylabel('True Label', fontsize=fontsize)
        fig.savefig(fig_name)
    # A 100x100-inch canvas is expensive to keep alive; free it.
    plt.close(fig)
def plot_most_incorrect(fig_name, incorrect, classes, n_images, normalize=True):
    """Show a grid of misclassified examples with true and predicted classes.

    Args:
        fig_name: Path the figure is written to.
        incorrect: Sequence of ``(image, true_label, probs)`` tuples. ``image``
            is presumably a (C, H, W) tensor (it is permuted to (H, W, C) for
            display) and ``probs`` a 1-D tensor of class probabilities.
        classes: Class names indexed by label.
        n_images: Requested number of examples; the grid is
            ``floor(sqrt(n_images))`` cells on each side.
        normalize: Min-max rescale each image before display.
    """
    rows = int(np.sqrt(n_images))
    cols = rows
    # Never index past the end of ``incorrect`` (a short list raised IndexError).
    n_shown = min(rows * cols, len(incorrect))
    fig = plt.figure(figsize=(25, 20))
    for i in range(n_shown):
        ax = fig.add_subplot(rows, cols, i + 1)
        image, true_label, probs = incorrect[i]
        # Private detached copy: normalization must not modify the caller's data.
        image = image.detach().permute(1, 2, 0).cpu().clone()
        true_prob = probs[true_label]
        incorrect_prob, incorrect_label = torch.max(probs, dim=0)
        true_class = classes[true_label]
        incorrect_class = classes[incorrect_label]
        if normalize:
            image = normalize_image(image)
        ax.imshow(image.numpy())
        ax.set_title(f'true label: {true_class} ({true_prob:.3f})\n'
                     f'pred label: {incorrect_class} ({incorrect_prob:.3f})')
        ax.axis('off')
    fig.subplots_adjust(hspace=0.4)
    fig.savefig(fig_name)
    plt.close(fig)
def get_pca(data, n_components=2):
    """Fit PCA on *data* and return its projection onto the leading components.

    Args:
        data: Array-like of shape (n_samples, n_features).
        n_components: Number of principal components to keep.

    Returns:
        Array of shape (n_samples, n_components) from ``PCA.fit_transform``.
    """
    return decomposition.PCA(n_components=n_components).fit_transform(data)
def plot_representations(fig_name, data, labels, classes, n_images=None):
    """Scatter-plot 2-D representations coloured by label and save the figure.

    Args:
        fig_name: Path the figure is written to.
        data: Array-like indexed as ``data[:, 0]`` / ``data[:, 1]``; only the
            first two columns are plotted (such as the 2-column default output
            of ``get_pca``).
        labels: Per-point values used as scatter colours.
        classes: Class names. Currently unused; kept for interface compatibility.
        n_images: If given, plot only the first ``n_images`` points.
    """
    if n_images is not None:
        data = data[:n_images]
        labels = labels[:n_images]
    fig = plt.figure(figsize=(15, 15))
    ax = fig.add_subplot(111)
    ax.scatter(data[:, 0], data[:, 1], c=labels, cmap='hsv')
    fig.savefig(fig_name)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def plot_filtered_images(fig_name, images, filters, n_filters=None, normalize=True):
    """Show each image alongside its responses to a bank of convolution filters.

    Row ``i`` holds image ``i``: the original in the first column, then one
    column per filter with the corresponding ``F.conv2d`` output channel.

    Args:
        fig_name: Path the figure is written to.
        images: Sequence of same-shaped image tensors, presumably (C, H, W);
            they are stacked into a batch for ``F.conv2d``.
        filters: Convolution weights, presumably (out_channels, C, kH, kW).
            A module's weight ``Parameter`` may be passed directly.
        n_filters: If given, only the first ``n_filters`` filters are used.
        normalize: Min-max rescale each panel before display.
    """
    # Detach both inputs and convolve without autograd: an output that
    # requires grad (e.g. when a weight Parameter is passed) cannot be
    # converted with ``.numpy()``.
    images = torch.stack(list(images), dim=0).detach().cpu()
    filters = filters.detach().cpu()
    if n_filters is not None:
        filters = filters[:n_filters]
    n_images = images.shape[0]
    n_filters = filters.shape[0]
    with torch.no_grad():
        filtered_images = F.conv2d(images, filters)

    n_cols = n_filters + 1  # one "Original" column plus one per filter
    fig = plt.figure(figsize=(30, 30))
    for i in range(n_images):
        image = images[i]
        if normalize:
            image = normalize_image(image)
        ax = fig.add_subplot(n_images, n_cols, i * n_cols + 1)
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title('Original')
        ax.axis('off')
        for j in range(n_filters):
            image = filtered_images[i][j]
            if normalize:
                image = normalize_image(image)
            ax = fig.add_subplot(n_images, n_cols, i * n_cols + j + 2)
            ax.imshow(image.numpy(), cmap='bone')
            ax.set_title(f'Filter {j+1}')
            ax.axis('off')
    fig.subplots_adjust(hspace=-0.7)  # negative spacing packs the rows tightly
    fig.savefig(fig_name)
    plt.close(fig)
def plot_filters(fig_name, filters, normalize=True):
    """Draw convolution filters as images on a square grid and save the figure.

    Only the first ``floor(sqrt(n_filters))**2`` filters are drawn.

    Args:
        fig_name: Path the figure is written to.
        filters: Filter weights indexed as ``filters[i]`` and permuted
            (1, 2, 0) for display; presumably (n_filters, 3, kH, kW) so each
            filter renders as an RGB image — TODO confirm for other channel counts.
        normalize: Min-max rescale each filter before display.
    """
    # ``.cpu()`` returns the same tensor when it is already on the CPU, so
    # clone to make sure normalization can never overwrite the caller's
    # weights; detach so the values can be converted to NumPy.
    filters = filters.detach().cpu().clone()
    n_filters = filters.shape[0]
    side = int(np.sqrt(n_filters))
    fig = plt.figure(figsize=(30, 15))
    for i in range(side * side):
        image = filters[i]
        if normalize:
            image = normalize_image(image)
        ax = fig.add_subplot(side, side, i + 1)
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.axis('off')
    fig.subplots_adjust(wspace=-0.9)
    fig.savefig(fig_name)
    plt.close(fig)
def plot_tsne(fig_name, all_features, all_labels):
    """Embed features in 2-D with t-SNE, scatter them coloured by label, save and show.

    ``random_state=42`` makes the embedding reproducible across calls.

    Args:
        fig_name: Path the figure is written to.
        all_features: Array-like of shape (n_samples, n_features).
        all_labels: Per-sample values used as scatter colours.
    """
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(all_features)
    fig, ax = plt.subplots(figsize=(10, 7))
    scatter = ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=all_labels, cmap='viridis', s=5)
    fig.colorbar(scatter, ax=ax)
    ax.set_title('t-SNE Visualization')
    # Save BEFORE showing: once shown, interactive/inline backends discard the
    # figure and a subsequent savefig writes a blank image.
    fig.savefig(fig_name)
    plt.show()
    plt.close(fig)
def plot_grad_cam(images, cams, predicted_labels, true_labels, classes, path):
    """Plot each image (top row) above its Grad-CAM overlay (bottom row) and save.

    Args:
        images: Sequence of image tensors, presumably (3, H, W) RGB with float
            values in [0, 1] — TODO confirm.
        cams: Sequence of Grad-CAM maps, presumably (H, W) arrays in [0, 1]
            matching each image's spatial size (they are scaled by 255 and
            added to the image).
        predicted_labels: Predicted class index per image.
        true_labels: Ground-truth class index per image.
        classes: Class names indexed by label.
        path: Path the figure is written to.
    """
    # squeeze=False keeps ``axs`` 2-D even for a single image; with ncols=1
    # matplotlib would return a 1-D array and ``axs[0, i]`` would fail.
    fig, axs = plt.subplots(nrows=2, ncols=len(images), figsize=(20, 10), squeeze=False)
    for i, (img, cam, pred_label, true_label) in enumerate(
            zip(images, cams, predicted_labels, true_labels)):
        img_np = img.detach().permute(1, 2, 0).cpu().numpy()

        # Top row: the original image, titled with predicted and true class.
        axs[0, i].imshow(img_np)
        pred_class_name = classes[pred_label]
        true_class_name = classes[true_label]
        axs[0, i].set_title(f"Predicted: {pred_class_name}\nTrue: {true_class_name}", fontsize=12)

        # Grayscale background (kept 3-channel) for the overlay.
        grayscale_img = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        grayscale_img = cv2.cvtColor(grayscale_img, cv2.COLOR_GRAY2RGB)

        # applyColorMap returns BGR; convert to RGB so matplotlib renders hot
        # regions red instead of blue.
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255
        cam_img = heatmap + np.float32(grayscale_img)
        cam_img = cam_img / np.max(cam_img)  # rescale so the brightest pixel is 1

        # Bottom row: the Grad-CAM overlay.
        axs[1, i].imshow(cam_img)

    # Row labels on the leftmost column.
    axs[0, 0].set_ylabel("Original Image", fontsize=14, rotation=90, labelpad=10)
    axs[1, 0].set_ylabel("Grad-CAM", fontsize=14, rotation=90, labelpad=10)

    # Hide ticks and frame but keep the axes on: ax.axis('off') also hides the
    # axis labels, which made the row labels above invisible.
    for ax in axs.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)