from io import BytesIO

import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from monai.transforms import (
    Compose, Resize, ResizeWithPadOrCrop,
)
from PIL import Image
from pytorch_grad_cam import GradCAM
from torchvision.models import resnet50

import utils
class ResNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        backbone = resnet50()
        # Adapt the stem for single-channel (grayscale) input: rebuild conv1
        # with in_channels=1 and initialise it by summing the original RGB
        # filters over the channel dimension.
        num_input_channel = 1
        layer = backbone.conv1
        new_layer = nn.Conv2d(
            in_channels=num_input_channel,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            bias=layer.bias is not None,  # conv1 has no bias; pass a bool, not the (None) tensor
        )
        new_layer.weight = nn.Parameter(layer.weight.sum(dim=1, keepdim=True))
        backbone.conv1 = new_layer
        # Replace the classification head with a two-class MLP. Dropout(0) is
        # a no-op but is kept so the module layout matches the checkpoint.
        backbone.fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0),
            nn.Linear(1024, 2),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)
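
# Quick sanity check (illustrative only; the dummy input shape is an
# assumption matching the 416x628 preprocessing below):
#   _m = ResNet().eval()
#   _m(torch.randn(2, 1, 416, 628)).shape  # expected: torch.Size([2, 2])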
# Validation-time preprocessing: CLAHE contrast enhancement (CustomCLAHE
# appears to load the image from the given path), resize the longest side
# to 628 px, then pad or crop to a fixed 416x628 canvas.
val_transforms_416x628 = Compose(
    [
        utils.CustomCLAHE(),
        Resize(spatial_size=628, mode="bilinear", align_corners=True, size_mode="longest"),
        ResizeWithPadOrCrop(spatial_size=(416, 628)),
    ]
)
# Restore the trained weights from the Lightning checkpoint onto the CPU.
checkpoint = torch.load("classification_model.ckpt", map_location=torch.device("cpu"))
model = ResNet()
model.load_state_dict(checkpoint["state_dict"])
model.eval()
def load_and_classify_image(image_path, device):
    """Preprocess the image at image_path and return (class probabilities, input tensor)."""
    # nn.Module.to() moves parameters in place; rebinding `model` here would
    # shadow the global and raise UnboundLocalError.
    model.to(device)
    image = val_transforms_416x628(image_path)
    image = image.unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(image)
    prediction = torch.nn.functional.softmax(prediction, dim=1).squeeze(0)
    return prediction.to("cpu"), image.to("cpu")
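
# Example usage (the file name is a hypothetical placeholder):
#   probs, img = load_and_classify_image("input.png", torch.device("cpu"))
#   probs is a length-2 tensor of softmax probabilities.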
def make_GradCAM(image, device):
    """Overlay a Grad-CAM heatmap on the input image and return it as an RGBA array."""
    model.to(device)  # in-place, as above; no rebinding
    model.eval()
    # Keep a CPU copy for plotting before moving the tensor to the device,
    # since .numpy() fails on CUDA tensors.
    arr = image.detach().cpu().numpy().squeeze()
    image = image.to(device)
    # Use the last bottleneck block of layer4 as the target layer.
    target_layers = [model.model.layer4[-1]]
    cam = GradCAM(model=model, target_layers=target_layers)
    # targets=None makes Grad-CAM explain the highest-scoring class.
    grayscale_cam = cam(
        input_tensor=image,
        targets=None,
        aug_smooth=False,
        eigen_smooth=True,
    )
    # GradCAM returns a NumPy array of shape (batch, H, W), already on the
    # CPU; just drop the batch axis.
    grayscale_cam = grayscale_cam.squeeze()
    # Build an "inferno" variant whose lowest bin is black, so near-zero
    # activations stay dark in the overlay.
    inferno = mpl.colormaps["inferno"]
    newcolors = inferno(np.linspace(0, 1, 256))
    newcolors[0, :3] = 0
    cam_cmap = mcolors.ListedColormap(newcolors)
    plt.figure(figsize=(10, 10))
    plt.imshow(arr, cmap="gray")
    plt.imshow(grayscale_cam, cmap=cam_cmap, alpha=0.5)
    plt.axis("off")
    buffer2 = BytesIO()
    plt.savefig(buffer2, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()  # release the figure so repeated calls don't leak memory
    buffer2.seek(0)
    gradcam_image = np.array(Image.open(buffer2)).squeeze()
    return gradcam_image
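
# Minimal end-to-end sketch (the example path and output file name are
# assumptions, not part of the original Space):
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    probs, img = load_and_classify_image("example.png", device)
    overlay = make_GradCAM(img, device)
    Image.fromarray(overlay).save("gradcam_overlay.png")
    print("class probabilities:", probs.tolist())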