# utils.py — uploaded by wgetdd; last change "Update utils.py" (commit 165456f, 6.86 kB).
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import transforms
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import matplotlib.pyplot as plt
import textwrap
import io
def apply_normalization(chennels):
    """Build the normalization layer used throughout ``CustomResnet``.

    Args:
        chennels: Number of feature channels to normalize. The spelling of
            the parameter name is kept so keyword callers keep working.

    Returns:
        A freshly constructed ``nn.BatchNorm2d`` sized to ``chennels``.
    """
    norm_layer = nn.BatchNorm2d(num_features=chennels)
    return norm_layer
class CustomResnet(nn.Module):
    """ResNet-style classifier: prep layer, three conv stages, two residual blocks.

    Spatial trace for a 3x32x32 input: every conv uses padding=1/stride=1
    (size-preserving) and each ``MaxPool2d(2, 2)`` halves the map, so
    preplayer -> 32x32, convlayer1 -> 16x16, convlayer2 -> 8x8,
    convlayer3 -> 4x4, and ``maxpool3`` (kernel 4, stride 2) -> 1x1.
    That leaves exactly 512 features per sample for the 10-way linear head.
    """

    def __init__(self):
        super().__init__()
        # PrepLayer: Conv 3x3 (s1, p1) >> BN >> ReLU [64k]
        self.preplayer = nn.Sequential(
            nn.Conv2d(3, 64, (3, 3), padding=1, stride=1, bias=False),
            apply_normalization(64),
            nn.ReLU(),
        )
        # Layer1: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> ReLU [128k]
        self.convlayer1 = nn.Sequential(
            nn.Conv2d(64, 128, (3, 3), padding=1, stride=1, bias=False),
            nn.MaxPool2d(2, 2),
            apply_normalization(128),
            nn.ReLU(),
        )
        # R1 = ResBlock((Conv-BN-ReLU-Conv-BN-ReLU))(X) [128k]
        self.reslayer1 = nn.Sequential(
            nn.Conv2d(128, 128, (3, 3), padding=1, stride=1, bias=False),
            apply_normalization(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, (3, 3), padding=1, stride=1, bias=False),
            apply_normalization(128),
            nn.ReLU(),
        )
        # Layer2: Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> ReLU [256k]
        self.convlayer2 = nn.Sequential(
            nn.Conv2d(128, 256, (3, 3), padding=1, stride=1, bias=False),
            nn.MaxPool2d(2, 2),
            apply_normalization(256),
            nn.ReLU(),
        )
        # Layer3: X = Conv 3x3 (s1, p1) >> MaxPool2D >> BN >> ReLU [512k]
        self.convlayer3 = nn.Sequential(
            nn.Conv2d(256, 512, (3, 3), padding=1, stride=1, bias=False),
            nn.MaxPool2d(2, 2),
            apply_normalization(512),
            nn.ReLU(),
        )
        # R2 = ResBlock((Conv-BN-ReLU-Conv-BN-ReLU))(X) [512k]
        self.reslayer2 = nn.Sequential(
            nn.Conv2d(512, 512, (3, 3), padding=1, stride=1, bias=False),
            apply_normalization(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, (3, 3), padding=1, stride=1, bias=False),
            apply_normalization(512),
            nn.ReLU(),
        )
        # Kernel 4 / stride 2: collapses the 4x4 map (32x32 inputs) to 1x1.
        self.maxpool3 = nn.MaxPool2d(4, 2)
        self.linear1 = nn.Linear(512, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        The network is sized for 32x32 inputs (see the class docstring).
        Other sizes leave more than 512 features per sample, which now raises
        in ``linear1`` instead of being silently reshaped into extra rows.
        """
        x = self.preplayer(x)
        x = self.convlayer1(x)
        x = x + self.reslayer1(x)  # residual connection around R1
        x = self.convlayer2(x)
        x = self.convlayer3(x)
        x = x + self.reslayer2(x)  # residual connection around R2
        x = self.maxpool3(x)
        # Flatten per sample: the previous view(-1, 512) folded any leftover
        # spatial positions into the batch dimension, corrupting the output.
        x = torch.flatten(x, 1)
        x = self.linear1(x)
        return F.log_softmax(x, dim=-1)
def resize_image(image, target_size=(200, 200)):
    """Resize ``image`` with OpenCV to ``target_size``.

    OpenCV interprets ``dsize`` as ``(width, height)`` rather than
    ``(rows, cols)``; with the square default the distinction is moot.
    """
    resized = cv2.resize(image, dsize=target_size)
    return resized
def wrap_text(text, width=20):
    """Hard-wrap ``text`` into lines of at most ``width`` characters.

    Returns one string with the wrapped lines joined by newlines; it is used
    to keep subplot titles narrow.
    """
    wrapped_lines = textwrap.wrap(text, width)
    return "\n".join(wrapped_lines)
import io
def get_img_from_fig(fig, dpi=180):
    """Render a Matplotlib figure into an RGB ``uint8`` numpy image.

    Args:
        fig: Anything with a ``savefig(buffer, format=..., dpi=...)`` method,
            e.g. a ``Figure`` or the ``matplotlib.pyplot`` module (whose
            ``savefig`` writes the current figure).
        dpi: Resolution of the intermediate PNG.

    Returns:
        An ``(H, W, 3)`` array in RGB channel order.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format="png", dpi=dpi)
        # getvalue() copies the bytes, so the array outlives the buffer.
        encoded_png = np.frombuffer(png_buffer.getvalue(), dtype=np.uint8)
    # OpenCV decodes colour images as BGR; convert to RGB before returning.
    bgr_image = cv2.imdecode(encoded_png, cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
def save_plot_as_image(images, texts, output_path):
    """Lay out ``images`` in a titled grid and return it as an RGB array.

    Args:
        images: Non-empty sequence of images accepted by ``Axes.imshow``.
        texts: Optional per-image titles. ``None``, or a list shorter than
            ``images``, leaves the remaining panels untitled. Each title is
            wrapped with ``wrap_text``.
        output_path: Currently unused; nothing is written to disk. Kept for
            call-site compatibility.

    Returns:
        The rendered grid as a numpy image (via ``get_img_from_fig``).

    Raises:
        ValueError: If ``images`` is empty (previously a ZeroDivisionError).
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError("save_plot_as_image needs at least one image")
    num_cols = min(4, num_images)  # at most 4 columns
    num_rows = (num_images - 1) // num_cols + 1  # ceil(num_images / num_cols)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; otherwise a
    # single image yields a bare Axes and `.flat` raises AttributeError.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(3 * num_cols, 3 * num_rows), squeeze=False
    )
    fig.subplots_adjust(hspace=0.9 / num_rows)  # vertical gap between rows
    for i, ax in enumerate(axes.flat):
        if i < num_images:
            ax.imshow(images[i], cmap="gray")
        ax.axis("off")  # hide ticks on used panels and blank trailing ones
        if i < num_images and texts is not None and i < len(texts):
            ax.set_title(wrap_text(texts[i]), fontsize=12, pad=5)
    fig.tight_layout()
    try:
        return get_img_from_fig(fig)
    finally:
        # pyplot keeps every figure alive until closed; without this, each
        # call leaked one figure.
        plt.close(fig)
def get_gradcam(model, input_img, opacity, layer):
    """Render Grad-CAM overlays for every non-ReLU module of one model block.

    Args:
        model: A ``CustomResnet``-like model exposing ``convlayer1``-``3`` and
            ``reslayer1``-``2`` as iterable ``nn.Sequential`` blocks.
        input_img: Image accepted by ``transforms.ToTensor``.
        opacity: Forwarded as ``show_cam_on_image(image_weight=...)``.
        layer: One of ``"convblock1"``, ``"convblock2"``, ``"convblock3"``,
            ``"resblock1"``, ``"resblock2"``.

    Returns:
        A numpy image with one overlay panel per selected module, each titled
        with the module's ``repr`` (built by ``save_plot_as_image``).

    Raises:
        ValueError: If ``layer`` is not one of the names above (previously an
            UnboundLocalError).
    """
    layer_attrs = {
        "convblock1": "convlayer1",
        "convblock2": "convlayer2",
        "convblock3": "convlayer3",
        "resblock1": "reslayer1",
        "resblock2": "reslayer2",
    }
    if layer not in layer_attrs:
        raise ValueError(
            f"unknown layer {layer!r}; expected one of {sorted(layer_attrs)}"
        )
    target_block = getattr(model, layer_attrs[layer])
    # Every module in the block except ReLUs gets its own CAM panel.
    cam_layers = [m for m in target_block if not isinstance(m, nn.ReLU)]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = transforms.ToTensor()(input_img).to(device).unsqueeze(0)

    # Background for the overlays, computed once since it does not depend on
    # the layer. This Normalize maps each channel x -> x * 0.23 + 0.5.
    # NOTE(review): the ToTensor output is in [0, 1] and is never normalized
    # before reaching the model here, so this "inverse" normalization washes
    # the image toward [0.5, 0.73]. Kept to preserve existing visuals; confirm.
    inv_normalize = transforms.Normalize(
        mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
        std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
    )
    background = inv_normalize(input_tensor.squeeze(0).to("cpu"))
    rgb_img = background.permute(1, 2, 0).numpy()  # CHW -> HWC

    panels, titles = [], []
    for cam_layer in cam_layers:
        cam = GradCAM(model=model, target_layers=[cam_layer], use_cuda=False)
        # targets=None: the library picks the highest-scoring class.
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]
        overlay = show_cam_on_image(
            rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity
        )
        panels.append(resize_image(overlay))
        titles.append(str(cam_layer))
    return save_plot_as_image(panels, titles, "plot.png")
def get_misclassified_images(show_misclassified, num):
    """Load one pre-rendered misclassified-example image, or ``None`` if disabled.

    Args:
        show_misclassified: Truthy to load the image; falsy short-circuits.
        num: Example index; coerced with ``int`` (so ``3.0`` -> ``3.png``).

    Returns:
        The result of ``cv2.imread`` for the example file, or ``None`` when
        ``show_misclassified`` is falsy.
    """
    if not show_misclassified:
        return None
    # NOTE(review): cv2.imread yields BGR channel order, whereas this module's
    # get_img_from_fig returns RGB; confirm the consumer expects BGR here.
    image_path = f"missclassified_images_examples/{int(num)}.png"
    return cv2.imread(image_path)
def main_inference(num_of_output_classes, classes, model, input_img):
    """Run ``model`` on a single image and return class confidences.

    Args:
        num_of_output_classes: How many classes to report, taken in index
            order starting at 0 (not sorted by score). Coerced with ``int`` so
            float values (e.g. from UI sliders) are accepted.
        classes: Class names indexed by model output position.
        model: Classifier whose output for this one image flattens to the
            per-class scores. Softmax is applied here, so logits and
            log-probabilities (as ``CustomResnet`` returns) give the same
            result. Assumes the caller already put the model in eval mode.
        input_img: Image accepted by ``transforms.ToTensor``.

    Returns:
        ``{class_name: probability}`` for the first ``num_of_output_classes``
        classes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = transforms.ToTensor()(input_img).to(device).unsqueeze(0)
    # Single forward pass with no autograd graph. The old code ran the model
    # twice and discarded the second result and both argmax predictions.
    with torch.no_grad():
        outputs = model(input_tensor)
    probabilities = torch.softmax(outputs.flatten(), dim=0)
    return {
        classes[i]: float(probabilities[i])
        for i in range(int(num_of_output_classes))
    }