import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torchvision.models import resnet18

# Preprocessing applied to every input image: resize to 224x224, convert to a
# float tensor, and normalize with the ImageNet channel mean/std. Built once at
# import time instead of on every predict_image() call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class ResNet18Classifier(nn.Module):
    """ResNet-18 backbone whose final linear layer emits ``num_classes`` scores."""

    def __init__(self, num_classes=3):
        super().__init__()
        # weights=None: the backbone starts untrained; trained weights are
        # expected to be loaded afterwards (see load_model below).
        self.resnet = resnet18(weights=None)
        # Swap the stock classifier head for one sized to our class count.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw (unnormalized) per-class scores for batch ``x``."""
        return self.resnet(x)


def load_model(model_path="model/best_classification_model.pth", num_classes=3,
               device="cpu"):
    """Build a ResNet18Classifier, load its weights and switch it to eval mode.

    Args:
        model_path: Path to a checkpoint holding this model's ``state_dict``.
        num_classes: Number of output classes; must match the checkpoint.
        device: Device to place the loaded model on (default ``"cpu"``).

    Returns:
        The model in evaluation mode, located on ``device``.

    Raises:
        RuntimeError: from ``load_state_dict`` if checkpoint keys or shapes do
            not match the model.
    """
    model = ResNet18Classifier(num_classes=num_classes)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints, or pass
    # weights_only=True (torch >= 1.13) to restrict what can be unpickled.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def predict_image(image_path, model, class_names):
    """Classify a single image and return the predicted class name.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        model: Module mapping a (1, 3, 224, 224) float tensor to per-class
            scores along dim 1. It should already be in eval mode
            (load_model does this), since this function does not change it.
        class_names: Indexable by the predicted class index, e.g. a list of
            names ordered like the model's output classes.

    Returns:
        ``class_names[i]`` where ``i`` is the index of the highest score.
    """
    # Use a context manager so the underlying file is closed promptly;
    # convert() loads the pixel data into a new image before the file closes.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    batch = _PREPROCESS(image).unsqueeze(0)  # add batch dim -> (1, 3, 224, 224)

    # Put the input on the same device as the model to avoid a device-mismatch
    # error when the model has been moved off the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    with torch.no_grad():
        outputs = model(batch)
    # argmax returns the first maximal index, matching torch.max(outputs, 1).
    predicted = outputs.argmax(dim=1)
    return class_names[predicted.item()]