Spaces:
Sleeping
Sleeping
import torch | |
from torchvision import models, transforms | |
from PIL import Image | |
# Output labels in network-output order: position i of the classifier's
# logits is reported as CLASS_NAMES[i], and the classifier head is sized
# from len(CLASS_NAMES). The list is alphabetical, which presumably mirrors
# the label order used when the checkpoint was trained (e.g. torchvision's
# ImageFolder sorts class folders) — NOTE(review): confirm against training.
CLASS_NAMES = [
    'apple',
    'bread',
    'fried_chicken',
    'hamburger',
    'pizza',
    'popcorn',
    'salad',
    'steak',
    'taco',
]
class ImageClassifier:
    """Single-image classifier backed by a fine-tuned torchvision ResNet-50.

    The network is a ResNet-50 whose final fully-connected layer is resized
    to ``len(class_names)`` outputs. All of its weights come from the
    checkpoint at ``model_path``.

    Args:
        model_path: Path to a ``state_dict`` saved for this exact
            architecture (ResNet-50 with the resized ``fc`` layer).
        device: Device to load the weights onto and run inference on, e.g.
            ``'cpu'`` or ``'cuda'``; anything ``torch.device`` accepts.
        class_names: Labels indexed by the network's output position.
            Defaults to the module-level ``CLASS_NAMES``.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint's keys or
            tensor shapes do not match the architecture built here.
    """

    def __init__(self, model_path, device='cpu', class_names=None):
        self.device = torch.device(device)
        # Copy, so later mutation of the caller's list (or of CLASS_NAMES)
        # cannot desynchronise labels from the fc layer's output size.
        self.class_names = list(CLASS_NAMES if class_names is None else class_names)

        # Preprocessing: resize shorter side to 256, centre-crop 224x224,
        # convert to a [0, 1] tensor, then normalise per channel with the
        # standard ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        # Build the bare architecture. Downloading pretrained ImageNet
        # weights here would be wasted work: load_state_dict below is strict,
        # so every parameter and buffer is overwritten by the checkpoint.
        self.model = models.resnet50(weights=None)
        # Replace the classification head to match the number of labels.
        num_ftrs = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_ftrs, len(self.class_names))

        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted checkpoints (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # map_location only places the loaded tensors; load_state_dict copies
        # them into the model's existing parameters, so the model itself must
        # be moved for inference to actually run on `device`.
        self.model.to(self.device)
        self.model.eval()  # BatchNorm layers use their running statistics

    def classify_image(self, image):
        """Return the predicted label for a single image.

        Args:
            image: A PIL image in any mode (non-RGB images such as RGBA PNGs,
                grayscale or palette images are converted to RGB first), or
                any other input ``self.transform`` accepts.

        Returns:
            The entry of ``self.class_names`` at the highest-scoring output.
        """
        # PIL images expose their colour mode as a string attribute; the
        # 3-channel Normalize and the ResNet stem both require RGB input.
        # (torch.Tensor has a ``mode`` *method*, hence the str check.)
        mode = getattr(image, 'mode', None)
        if isinstance(mode, str) and mode != 'RGB':
            image = image.convert('RGB')

        # Add a batch dimension and place the input next to the model.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
        # Batch size is 1, so .item() yields the single predicted index.
        predicted = logits.argmax(dim=1).item()
        return self.class_names[predicted]