trash-ai / model.py
kndlr
Initial commit
89535c7
raw
history blame contribute delete
No virus
3.49 kB
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import time
class Prediction:
    """Value object bundling the outputs of one image classification.

    In this file ``Pipeline.predict_image`` fills it as follows:

    Attributes:
        data: dict mapping each class name to the model's score for it.
        heatmap: the Grad-CAM overlay image returned by ``Pipeline.visualize``.
        duration: forward-pass time in whole milliseconds.
    """

    def __init__(self, data, heatmap, duration):
        self.data, self.heatmap, self.duration = data, heatmap, duration
class Pipeline:
    """Load the trained waste classifiers and run prediction + Grad-CAM on images.

    ``self.model`` maps a checkpoint name (e.g. ``"resnet50_0810"``) to a
    classifier module (``ResNet50`` / ``ResNet152``) that has its trained
    weights loaded, is in eval mode, and carries a ``cam`` attribute: a
    ``GradCAM`` over the backbone's ``layer4``.
    """

    def __init__(self, model_dir="models"):
        """Build the preprocessing transform, pick a device and load all models.

        Args:
            model_dir: Directory holding the ``<name>.pt`` checkpoints. The
                default ``"models"`` reproduces the original hard-coded paths
                (resolved relative to the current working directory).
        """
        self.classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
        # Fixed 256x256 network input; visualize() resizes the RGB image to the
        # same size so the CAM overlay lines up with it.
        self.transformations = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = {}
        # (checkpoint name, architecture), loaded in this order.
        checkpoints = (
            ("resnet50_0810", ResNet50),
            ("resnet152_0813", ResNet152),
            ("resnet152_0902", ResNet152),
        )
        for name, architecture in checkpoints:
            self.model[name] = self._load_model(architecture, f"{model_dir}/{name}.pt")

    def _load_model(self, architecture, checkpoint_path):
        """Instantiate ``architecture``, load its checkpoint, set eval mode and attach Grad-CAM."""
        network = self.to_device(architecture(self.classes), self.device)
        network.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        network.eval()
        # NOTE(review): the third positional argument is ``use_cuda`` in older
        # pytorch_grad_cam releases; newer releases removed it (the slot became
        # ``reshape_transform``). Confirm against the pinned version before upgrading.
        network.cam = GradCAM(network.network, [network.network.layer4], torch.cuda.is_available())
        return network

    def to_device(self, data, device):
        """Move a tensor or module to ``device`` and cast it to float32."""
        return data.to(device, torch.float32)

    def _timed_forward(self, model, xb):
        """Run ``self.model[model]`` on ``xb`` without autograd and time it.

        Returns:
            ``(output, elapsed_ms)`` where ``elapsed_ms`` is whole milliseconds.
        """
        network = self.model[model]
        on_cuda = self.device.type == 'cuda'
        # Inference only: no autograd graph is needed, which saves memory and time.
        with torch.no_grad():
            # CUDA kernels run asynchronously; synchronize so the timer measures
            # the actual computation rather than just the kernel launch.
            if on_cuda:
                torch.cuda.synchronize(self.device)
            start = time.perf_counter()
            output = network(xb)
            if on_cuda:
                torch.cuda.synchronize(self.device)
            elapsed = time.perf_counter() - start
        return output, int(elapsed * 1000)

    def predict_image(self, model, image):
        """Classify a PIL image with the named model and build its Grad-CAM overlay.

        Args:
            model: Key into ``self.model`` (e.g. ``"resnet50_0810"``).
            image: A PIL image. It is converted to RGB first, so RGBA, palette
                and greyscale inputs produce a 3-channel tensor.

        Returns:
            A ``Prediction`` with the following fields:
            - ``data``: maps each class name to the model's output score.
            - ``heatmap``: the Grad-CAM overlay image.
            - ``duration``: the forward-pass time in whole milliseconds.

        Raises:
            KeyError: if ``model`` is not a loaded model name.
        """
        image = image.convert('RGB')
        xb = self.to_device(self.transformations(image).unsqueeze(0), self.device)
        scores, duration = self._timed_forward(model, xb)
        data = dict(zip(self.classes, scores[0].tolist()))
        return Prediction(data, self.visualize(model, image, xb), duration)

    def visualize(self, model, rgb_image, input_tensor):
        """Blend the Grad-CAM heatmap of ``model`` for ``input_tensor`` over ``rgb_image``.

        Args:
            model: Key into ``self.model``.
            rgb_image: PIL image that produced ``input_tensor``. It is converted
                to RGB and resized to 256x256 to match the network input.
            input_tensor: Batched model input. Only the CAM of the first batch
                item is used.

        Returns:
            A PIL image of the heatmap overlaid on ``rgb_image``.
        """
        # Grad-CAM needs gradients, so this must not run under torch.no_grad().
        rgb = np.float32(np.array(rgb_image.convert('RGB').resize((256, 256)))) / 255
        greyscale_cam = self.model[model].cam(input_tensor)[0, :]
        overlay = show_cam_on_image(rgb, greyscale_cam, use_rgb=True)
        return Image.fromarray(overlay)
class ResNet50(torch.nn.Module):
    """ResNet-50 backbone with a ``len(classes)``-way head and sigmoid outputs.

    Args:
        classes: Sequence of class names. Only its length is used, to size the
            final linear layer.
        weights: Forwarded to ``torchvision.models.resnet50``. The default
            ``"DEFAULT"`` keeps the original behaviour of starting from the
            torchvision pretrained weights. Pass ``None`` when a trained
            checkpoint is loaded right afterwards; this avoids fetching weights
            that would be overwritten anyway.
    """

    def __init__(self, classes, weights="DEFAULT"):
        super().__init__()
        self.network = models.resnet50(weights=weights)
        # Swap the stock classification head for one logit per class.
        self.network.fc = torch.nn.Linear(self.network.fc.in_features, len(classes))

    def forward(self, xb):
        """Return per-class sigmoid scores in [0, 1] for the batch ``xb``."""
        return torch.sigmoid(self.network(xb))
class ResNet152(torch.nn.Module):
    """ResNet-152 backbone with a ``len(classes)``-way head and sigmoid outputs.

    Args:
        classes: Sequence of class names. Only its length is used, to size the
            final linear layer.
        weights: Forwarded to ``torchvision.models.resnet152``. The default
            ``"DEFAULT"`` keeps the original behaviour of starting from the
            torchvision pretrained weights. Pass ``None`` when a trained
            checkpoint is loaded right afterwards; this avoids fetching weights
            that would be overwritten anyway.
    """

    def __init__(self, classes, weights="DEFAULT"):
        super().__init__()
        self.network = models.resnet152(weights=weights)
        # Swap the stock classification head for one logit per class.
        self.network.fc = torch.nn.Linear(self.network.fc.in_features, len(classes))

    def forward(self, xb):
        """Return per-class sigmoid scores in [0, 1] for the batch ``xb``."""
        return torch.sigmoid(self.network(xb))