|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.models as models |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from PIL import Image |
|
import numpy as np |
|
import time |
|
|
|
class Prediction:
    """Result bundle for a single prediction.

    Attributes:
        data: Whatever score payload the producer supplies (stored as-is).
        heatmap: Visual explanation that accompanies the scores (stored as-is).
        duration: Timing figure reported by the producer (stored as-is).
    """

    def __init__(self, data, heatmap, duration):
        # The three values are stored untouched; no validation or copying.
        self.data, self.heatmap, self.duration = data, heatmap, duration
|
|
|
class Pipeline:
    """Serve class scores and Grad-CAM overlays from several ResNet checkpoints.

    On construction every checkpoint named in ``__init__`` is loaded onto the
    GPU when one is available (otherwise the CPU), switched to eval mode and
    paired with a ``GradCAM`` instance. Loaded models live in ``self.model``,
    keyed by checkpoint name.
    """

    # Size used both by the tensor transform and by the overlay image; the
    # two must agree because the heat map is blended onto the resized image.
    # The size is square, so (height, width) vs (width, height) ordering
    # between torchvision and PIL makes no difference.
    IMAGE_SIZE = (256, 256)

    def __init__(self):
        """Load every checkpoint from the ``models/`` directory.

        Any exception raised by ``torch.load`` or ``load_state_dict`` (for
        example a missing file or mismatched weights) propagates to the caller.
        """
        self.classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
        self.transformations = transforms.Compose(
            [transforms.Resize(self.IMAGE_SIZE), transforms.ToTensor()]
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Checkpoint name -> architecture class. The weights file for each
        # entry is expected at "models/<name>.pt".
        architectures = {
            "resnet50_0810": ResNet50,
            "resnet152_0813": ResNet152,
            "resnet152_0902": ResNet152,
        }
        self.model = {}
        for name, architecture in architectures.items():
            self.model[name] = self._load_model(architecture, f"models/{name}.pt")

    def _load_model(self, architecture, weights_path):
        """Build ``architecture``, load ``weights_path`` into it and attach Grad-CAM."""
        network = self.to_device(architecture(self.classes), self.device)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        network.load_state_dict(torch.load(weights_path, map_location=self.device))
        network.eval()
        # NOTE(review): the third positional argument is presumably GradCAM's
        # ``use_cuda`` flag; confirm against the installed pytorch_grad_cam
        # version, since its constructor signature has changed over time.
        network.cam = GradCAM(
            network.network, [network.network.layer4], torch.cuda.is_available()
        )
        return network

    def to_device(self, data, device):
        """Return ``data`` moved to ``device`` and cast to ``torch.float32``.

        Works for anything exposing ``.to(device, dtype)``, i.e. both tensors
        and modules.
        """
        return data.to(device, torch.float32)

    def predict_image(self, model, image):
        """Score ``image`` with the named model and build its Grad-CAM overlay.

        Args:
            model: Key into ``self.model`` (a checkpoint name).
            image: Input picture; expected to be a PIL image, since it is
                later resized with ``.resize`` in ``visualize``.

        Returns:
            A ``Prediction`` holding a ``{class name: score}`` dict, the
            overlay image and the forward-pass duration in whole milliseconds.

        Raises:
            KeyError: If ``model`` is not a loaded checkpoint name.
        """
        # The networks take 3-channel input and the overlay blends with a
        # 3-channel heat map, so grayscale / RGBA pictures are converted first.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")

        xb = self.to_device(self.transformations(image).unsqueeze(0), self.device)
        network = self.model[model]

        started = time.perf_counter()
        # Plain scoring needs no autograd graph; Grad-CAM runs its own
        # gradient-enabled pass later, outside this context.
        with torch.no_grad():
            yb = network(xb)
        if self.device.type == "cuda":
            # CUDA kernels are launched asynchronously; wait for them so the
            # measured time covers the actual computation.
            torch.cuda.synchronize()
        elapsed_ms = int((time.perf_counter() - started) * 1000)

        # One bulk conversion instead of one tensor read per class.
        data = dict(zip(self.classes, yb[0].tolist()))
        return Prediction(data, self.visualize(model, image, xb), elapsed_ms)

    def visualize(self, model, rgb_image, input_tensor):
        """Overlay the named model's Grad-CAM heat map on ``rgb_image``.

        Args:
            model: Key into ``self.model`` (a checkpoint name).
            rgb_image: Picture to draw on; resized to ``IMAGE_SIZE`` here.
            input_tensor: Batched tensor passed to the model's ``cam`` object.

        Returns:
            The blended overlay produced by ``show_cam_on_image``, wrapped
            with ``Image.fromarray``.
        """
        resized = rgb_image.resize(self.IMAGE_SIZE)
        # Pixel values are scaled by 1/255 before being handed to
        # show_cam_on_image.
        background = np.float32(np.array(resized)) / 255
        # Only the first (and only) item of the batch is visualised.
        greyscale_cam = self.model[model].cam(input_tensor)[0, :]
        overlay = show_cam_on_image(background, greyscale_cam, use_rgb=True)
        return Image.fromarray(overlay)
|
|
|
class ResNet50(torch.nn.Module):
    """torchvision ResNet-50 with its ``fc`` layer resized to ``len(classes)``."""

    def __init__(self, classes, weights="DEFAULT"):
        """Build the network.

        Args:
            classes: Sized collection of class labels; only its length is
                used, as the output width of the replacement ``fc`` layer.
            weights: Forwarded to ``torchvision.models.resnet50``. The default
                keeps the previous behaviour. Pass ``None`` to skip the
                pretrained weights, e.g. when a checkpoint is loaded with
                ``load_state_dict`` right after construction.
        """
        super().__init__()
        self.network = models.resnet50(weights=weights)
        # Replace the final layer so the output has one value per class.
        self.network.fc = torch.nn.Linear(self.network.fc.in_features, len(classes))

    def forward(self, xb):
        """Return the element-wise sigmoid of the backbone's output for ``xb``."""
        return torch.sigmoid(self.network(xb))
|
|
|
class ResNet152(torch.nn.Module):
    """torchvision ResNet-152 with its ``fc`` layer resized to ``len(classes)``."""

    def __init__(self, classes, weights="DEFAULT"):
        """Build the network.

        Args:
            classes: Sized collection of class labels; only its length is
                used, as the output width of the replacement ``fc`` layer.
            weights: Forwarded to ``torchvision.models.resnet152``. The
                default keeps the previous behaviour. Pass ``None`` to skip
                the pretrained weights, e.g. when a checkpoint is loaded with
                ``load_state_dict`` right after construction.
        """
        super().__init__()
        self.network = models.resnet152(weights=weights)
        # Replace the final layer so the output has one value per class.
        self.network.fc = torch.nn.Linear(self.network.fc.in_features, len(classes))

    def forward(self, xb):
        """Return the element-wise sigmoid of the backbone's output for ``xb``."""
        return torch.sigmoid(self.network(xb))