# food_detector/image_classifier.py
# Source: initial commit 32f1e6c by kevin510.
import torch
from torchvision import models, transforms
from PIL import Image
# Output labels of the classifier. The classification head is sized by
# len(CLASS_NAMES), and logit index i is reported as CLASS_NAMES[i], so this
# order must match the order used when the checkpoint was trained.
CLASS_NAMES = [
    'apple',
    'bread',
    'fried_chicken',
    'hamburger',
    'pizza',
    'popcorn',
    'salad',
    'steak',
    'taco',
]
class ImageClassifier:
    """Classify an image into one of ``CLASS_NAMES`` with a fine-tuned ResNet-50.

    The network is a torchvision ResNet-50 whose final fully connected layer
    is replaced by a ``len(CLASS_NAMES)``-way linear layer. Its weights are
    loaded from a saved ``state_dict``.
    """

    def __init__(self, model_path, device='cpu'):
        """Build the network, load fine-tuned weights and prepare for inference.

        Args:
            model_path: Anything accepted by ``torch.load`` (path or file-like
                object) holding a ``state_dict`` for the modified ResNet-50.
            device: Device (string or ``torch.device``) that the weights are
                mapped to and on which inference runs, e.g. ``'cpu'`` or
                ``'cuda'``.
        """
        self.device = torch.device(device)
        # Resize the short side to 256, center-crop 224x224, convert to a
        # tensor and normalize with the ImageNet channel mean/std commonly
        # used with torchvision ResNets.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Build the architecture without downloading pretrained weights.
        # Every parameter and buffer is overwritten by the strict
        # load_state_dict below, so an ImageNet download would only cost
        # time and require network access.
        self.model = models.resnet50(weights=None)
        # Replace the classifier head to match the number of classes.
        num_ftrs = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_ftrs, len(CLASS_NAMES))
        # NOTE(review): torch.load unpickles the file. Only load checkpoints
        # from trusted sources (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # The model is constructed on CPU, so it must be moved explicitly for
        # the `device` argument to take effect.
        self.model.to(self.device)
        self.model.eval()  # Inference mode for BatchNorm layers.

    def classify_image(self, image):
        """Return the predicted class name for a single image.

        Args:
            image: A PIL image. Non-RGB modes (e.g. 'L', 'RGBA', 'P') are
                converted to RGB first.

        Returns:
            The entry of ``CLASS_NAMES`` with the highest logit.
        """
        # Normalize and the network stem expect exactly 3 channels. Grayscale,
        # alpha or palette images would otherwise fail.
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        # Add a batch dimension and place the input on the model's device.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
        return CLASS_NAMES[logits.argmax(dim=1).item()]