|
|
|
import torch |
|
import json |
|
from torchvision import transforms |
|
|
|
# Lookup table used by predict(): keys are class indices as strings (JSON
# object keys are always strings, e.g. "0"), values are the labels returned.
# JSON text is UTF-8 (RFC 8259), so decode it explicitly rather than relying
# on the platform's locale encoding, which breaks non-ASCII labels on e.g.
# Windows.
# NOTE(review): the relative path is resolved against the current working
# directory, not this module's directory — confirm that is intended.
with open('label_mapping.json', 'r', encoding='utf-8') as json_file:
    label_mapping = json.load(json_file)
|
|
|
def load_model(path):
    """Deserialize a TorchScript model from ``path`` onto the CPU.

    All stored tensors are remapped to the CPU device while loading, so a
    model saved from a GPU session can still be loaded on a CPU-only host.

    Args:
        path: File path (or file-like object) accepted by ``torch.jit.load``.

    Returns:
        The loaded ``torch.jit.ScriptModule``.
    """
    cpu = torch.device("cpu")
    return torch.jit.load(path, map_location=cpu)
|
|
|
# Preprocessing applied to every input image: resize to 224x224, then convert
# to a float tensor via ToTensor. The pipeline is stateless, so it is built
# once at import time instead of on every predict() call.
_PREPROCESS = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])


def predict(model, image, *, labels=None):
    """Classify a single image and return its label and confidence.

    No mean/std normalization is applied after ``ToTensor``.
    NOTE(review): presumably the model was trained on un-normalized inputs —
    confirm against the training pipeline.

    Args:
        model: A module with an ``eval()`` method (e.g. the result of
            ``load_model``). It is called on a batch of one preprocessed image
            and must return logits whose dimension 1 indexes classes.
        image: Input accepted by ``transforms.Resize`` / ``transforms.ToTensor``
            (presumably a PIL image — verify against callers).
            NOTE(review): ToTensor keeps the image's channel count, so an RGBA
            or grayscale image yields 4 or 1 channels; convert it to the mode
            the model expects before calling if that can happen.
        labels: Mapping from the predicted class index, as a string key
            (``"0"``, ``"1"``, ...), to its label. Defaults to the module-level
            ``label_mapping`` loaded from ``label_mapping.json``.

    Returns:
        ``(label, probability)``, where ``probability`` is the softmax
        probability of the predicted class rounded to 2 decimal places.

    Raises:
        KeyError: If the predicted class index has no entry in ``labels``.
    """
    if labels is None:
        labels = label_mapping

    model.eval()

    # Add a leading batch dimension: the model is run on a batch of one.
    batch = _PREPROCESS(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        # One call yields both the top probability and its class index.
        confidence, predicted_class = probabilities.max(dim=1)

    # Mapping keys are strings (JSON object keys), so stringify the index.
    predicted_label = labels[str(predicted_class.item())]
    return predicted_label, round(confidence.item(), 2)