|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import model_builder |
|
import matplotlib.pyplot as plt |
|
import argparse |
|
|
|
def predict_image(image_path: str, |
|
model_path: str, |
|
class_names: list): |
|
""" |
|
Predict class label of an image using a pre-trained model and plot the image with predicted class and probability as title. |
|
|
|
Args: |
|
image_path (str): Path to the input image. |
|
model_path (str): Path to the saved PyTorch model. |
|
class_names (list): List of class names. |
|
|
|
Returns: |
|
predicted_class (str): Predicted class label. |
|
probability_percentage (float): Probability percentage of the predicted class. |
|
""" |
|
|
|
image = Image.open(image_path).convert('RGB') |
|
|
|
|
|
transform = transforms.Compose([ |
|
transforms.Resize((112, 112)), |
|
transforms.ToTensor() |
|
]) |
|
|
|
|
|
image_tensor = transform(image).unsqueeze(0) |
|
|
|
|
|
model = model_builder.TrashClassificationCNNModel(input_shape=3, |
|
hidden_units=15, |
|
output_shape=len(class_names) |
|
) |
|
|
|
model.load_state_dict(torch.load(model_path)) |
|
model.eval() |
|
|
|
|
|
with torch.inference_mode(): |
|
output = model(image_tensor) |
|
probabilities = torch.softmax(output, dim=1).squeeze().tolist() |
|
predicted_index = torch.argmax(output, 1).item() |
|
probability_percentage = probabilities[predicted_index] * 100 |
|
predicted_class = class_names[predicted_index] |
|
|
|
|
|
plt.imshow(image_tensor.squeeze().permute(1, 2, 0)) |
|
plt.title(f'Predicted Class: {predicted_class} | Probability: {probability_percentage:.2f}%', fontdict={'family': 'serif', |
|
'color': 'black', |
|
'weight': 'normal', |
|
'size': 16,}) |
|
plt.axis(False) |
|
plt.show() |
|
|
|
return predicted_class, probability_percentage |
|
|
|
|
|
parser = argparse.ArgumentParser(description="For Prediction of an Image using a Loaded Model") |
|
parser.add_argument("--image", type=str, default=None, help="Path to the image to make predictions on.") |
|
parser.add_argument("--model_path", type=str, default=None, help="Path to the Trained Models State Dict.") |
|
args = parser.parse_args() |
|
|
|
IMAGE_PATH = args.image |
|
MODEL_PATH = args.model_path |
|
|
|
|
|
if not IMAGE_PATH: |
|
print("Please Enter Image Path using --image") |
|
raise SystemExit |
|
|
|
if not MODEL_PATH: |
|
print("Please Enter Model PAth using --model_path") |
|
raise SystemExit |
|
|
|
|
|
try: |
|
|
|
predict_image(image_path=IMAGE_PATH, |
|
model_path=MODEL_PATH, |
|
class_names=['Cardboard', |
|
'Food Organics', |
|
'Glass', |
|
'Metal', |
|
'Miscellaneous Trash', |
|
'Paper', |
|
'Plastic', |
|
'Textile Trash', |
|
'Vegetation']) |
|
except Exception as exception: |
|
print("INVALID MODEL_PATH OR INVALID IMAGE PATH") |
|
print(f"\n{exception}") |