Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| from torchvision.models import resnet18 | |
class ResNet18Classifier(nn.Module):
    """ResNet-18 network whose final fully connected layer emits ``num_classes`` scores.

    The backbone lives under the attribute ``resnet``, so every key in this
    module's state dict is prefixed with ``resnet.`` (e.g. ``resnet.fc.weight``).
    Checkpoints loaded into it must use the same naming.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # weights=None: no pretrained weights are downloaded; the trained
        # parameters are expected to be loaded from a checkpoint afterwards.
        backbone = resnet18(weights=None)
        # Replace the stock classification head with one sized for our classes.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Return the backbone's raw logits (output of the final Linear layer) for ``x``."""
        return self.resnet(x)
def load_model(model_path="model/best_classification_model.pth", num_classes=3, *, device="cpu"):
    """Build a ``ResNet18Classifier`` and load trained weights into it.

    Args:
        model_path: Path to a file holding a state dict, i.e. the mapping that is
            passed straight to ``load_state_dict``.
        num_classes: Number of output classes. Must match the size of the
            checkpoint's final layer.
        device: Where the checkpoint tensors are mapped while loading and where
            the returned model lives. Defaults to ``"cpu"``, which was the
            previously hard-coded behaviour.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
        RuntimeError: if the checkpoint's keys or tensor shapes do not match the
            model (``load_state_dict`` runs in strict mode).
    """
    model = ResNet18Classifier(num_classes=num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot run arbitrary code. A state dict needs nothing
    # more. This is PyTorch's default from 2.6, and passing it explicitly
    # silences the FutureWarning that 2.4-2.5 emit when it is omitted.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the existing (CPU) parameters, so the
    # module itself still has to be moved.
    model.to(device)
    model.eval()  # inference behaviour: batch-norm uses its running statistics
    return model
# Preprocessing applied to every input image. It is built once instead of on
# every call. The mean/std are the standard ImageNet channel statistics and
# presumably match the training-time preprocessing; confirm if the model changes.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it exposes none."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Plain callables or parameterless modules: keep the old CPU behaviour.
        return torch.device("cpu")


def predict_image(image_path, model, class_names):
    """Classify a single image and return the name of the predicted class.

    Args:
        image_path: Path or file object of the image to classify. An already
            loaded ``PIL.Image.Image`` is accepted as well.
        model: Callable mapping a ``(1, 3, 224, 224)`` normalised float tensor to
            ``(1, num_classes)`` scores. It should already be in eval mode
            (``load_model`` returns it that way).
        class_names: Sequence indexed by class id, ordered like the model's
            output units.

    Returns:
        ``class_names[i]``, where ``i`` is the index of the highest score.
    """
    if isinstance(image_path, Image.Image):
        image = image_path.convert('RGB')
    else:
        # The context manager closes the underlying file. convert() returns a
        # new, fully loaded image, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

    # Add the batch dimension and place the input on the model's device so
    # GPU-resident models work too.
    batch = _PREPROCESS(image).unsqueeze(0).to(_model_device(model))
    with torch.no_grad():
        logits = model(batch)
    return class_names[logits.argmax(dim=1).item()]