# ImageClassification / image_classification_api.py
# Uploaded by DanielIglesias97 in commit 696ce04:
# "First upload of the code for the ImageClassification demo."
import configparser
import pandas as pd
from PIL import Image
import torch
from torchvision import models, transforms
class ImageClassifierBase:
    """Abstract base for image classifiers that rank text labels by probability.

    Subclasses must override :meth:`get_model`. The base class provides
    device selection, image preprocessing, label loading and the
    classification routine itself.

    NOTE(review): the double-underscore-suffixed helper names
    (``__read_image__``, ``__read_text_labels__``) are kept for backward
    compatibility; dunder names are conventionally reserved for Python.
    """

    # CSV file holding the class labels. It is read with pandas' default
    # ``header=0``, so its first row is treated as a header, not a label.
    # A relative path resolves against the current working directory.
    # Subclasses or instances may override this.
    labels_path = 'imagenet_labels.csv'

    def __read_text_labels__(self):
        """Return every cell of the labels CSV as a flat 1-D numpy array.

        Cells are flattened row by row, so the array position of a label is
        expected to match the class index of the model output.
        NOTE(review): confirm the CSV really has a header row; otherwise the
        first label is swallowed as the column name.
        """
        text_labels = pd.read_csv(self.labels_path).values
        return text_labels.flatten()

    def __read_image__(self, device, image):
        """Preprocess ``image`` into a normalized single-image batch on ``device``.

        Args:
            device: ``torch.device`` the returned tensor is moved to.
            image: a PIL image (or anything ``transforms.ToTensor`` accepts).
                Images whose ``mode`` is not ``'RGB'`` (grayscale, RGBA,
                palette, ...) are converted to RGB first.

        Returns:
            A float tensor of shape (1, 3, 224, 224) on ``device``.
        """
        # Normalize below has exactly three channel statistics. A 1- or
        # 4-channel image would make it fail, so force RGB.
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        preprocess = transforms.Compose([
            # NOTE(review): the common ImageNet evaluation recipe resizes the
            # shorter side to 256 before the 224 center crop; confirm that
            # resizing straight to 224 is intended.
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Per-channel mean/std used by torchvision's ImageNet-pretrained models.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        input_tensor = preprocess(image)
        # Add the leading batch dimension and move to the target device.
        return input_tensor.unsqueeze(0).to(device)

    def get_device(self):
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_model(self, device):
        """Return the classification model, ready for inference on ``device``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement get_model(device)')

    def image_classification(self, model, input_image, device, hard_detection_threshold=0.0):
        """Classify ``input_image`` and rank every label by probability.

        Args:
            model: model returned by :meth:`get_model`. Its output for the
                single-image batch is treated as unnormalized class scores
                (logits) and converted to probabilities with softmax.
            input_image: image to classify (see ``__read_image__``).
            device: device the model lives on.
            hard_detection_threshold: labels whose probability is below this
                value are dropped. The default 0.0 keeps every label.

        Returns:
            A DataFrame with columns ``label`` and ``prob``, sorted by
            ``prob`` in descending order. Its index is the class position in
            the model output.

        Raises:
            ValueError: if the number of labels differs from the number of
                model outputs.
        """
        input_batch = self.__read_image__(device, input_image)
        with torch.no_grad():
            output = model(input_batch)
        text_labels = self.__read_text_labels__()
        return self._summarize(output[0], text_labels, hard_detection_threshold)

    @staticmethod
    def _summarize(scores, text_labels, threshold=0.0):
        """Turn one row of class scores into a ranked label/probability table.

        ``scores`` is a 1-D tensor of logits on any device. Softmax converts
        it to probabilities, rows below ``threshold`` are dropped, and the
        remaining rows are sorted by probability, highest first.
        """
        # pandas converts tensors via numpy, which cannot read CUDA memory,
        # so bring the scores to the CPU first.
        probs = torch.softmax(scores.detach().float().cpu(), dim=0).numpy()
        summary = pd.DataFrame({'label': text_labels, 'prob': probs})
        summary = summary[summary['prob'] >= threshold]
        return summary.sort_values(by=['prob'], ascending=False)
class ImageClassifierVGG16(ImageClassifierBase):
    """Image classifier backed by torchvision's ImageNet-pretrained VGG16."""

    def get_model(self, device):
        """Build pretrained VGG16, move it to ``device`` and put it in eval mode.

        Returns:
            The ``torch.nn.Module``, ready for inference.
        """
        weights_enum = getattr(models, 'VGG16_Weights', None)
        if weights_enum is not None:
            # Newer torchvision: the ``weights`` API replaces the deprecated
            # ``pretrained`` flag. IMAGENET1K_V1 is the checkpoint that
            # ``pretrained=True`` maps to.
            model = models.vgg16(weights=weights_enum.IMAGENET1K_V1)
        else:
            # Older torchvision only understands the boolean flag.
            model = models.vgg16(pretrained=True)
        model.to(device)
        # Switch layers such as dropout to inference behavior.
        model.eval()
        return model