| | import torch |
| | import torch.nn as nn |
| | from PIL import Image |
| | from torch import Tensor |
| | from torchvision import transforms |
| |
|
| | from model import VGG16WithCNN |
| |
|
| |
|
def getModel(device: torch.device, model_path: str, num_classes: int = 5):
    """Build a VGG16WithCNN model, load saved weights into it, and move it to a device.

    Args:
        device: Device the model is placed on. It is also passed to
            ``torch.load`` as ``map_location``, so a checkpoint saved on a GPU
            can be loaded on a CPU-only machine.
        model_path: Path to a checkpoint containing the model's ``state_dict``.
        num_classes: Argument passed to ``VGG16WithCNN``. This is presumably the
            number of output classes (TODO confirm against model.py). It defaults
            to 5, the value that was previously hard-coded.

    Returns:
        The model with the checkpoint weights loaded, located on ``device``.
    """
    model = VGG16WithCNN(num_classes)

    # Without map_location, tensors saved from "cuda:N" are restored onto that
    # same CUDA device. That raises on hosts where torch.cuda.is_available() is
    # False, which is exactly the CPU fallback the caller selects.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.to(device)
    return model
| |
|
| |
|
def preprocess_image(image_path: str, image_size=(224, 224)) -> Tensor:
    """Load an image from disk and turn it into a normalized single-image batch.

    Args:
        image_path: Path to the image file.
        image_size: Target (height, width) passed to ``transforms.Resize``.

    Returns:
        A float tensor with a leading batch dimension of 1, i.e. shaped
        (1, 3, H, W) where (H, W) == image_size.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist (raised by
            ``Image.open``; the ``__main__`` block relies on this).
    """
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Standard ImageNet channel mean/std. Presumably these match the
            # normalization used during training; confirm.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Use a context manager so the underlying file handle is closed.
    # convert() loads the pixel data and returns a new image, so `image`
    # remains usable after the file is closed.
    with Image.open(image_path) as img:
        image: Image.Image = img.convert("RGB")

    image_tensor: Tensor = transform(image)

    # Add the batch dimension expected by the model: (3, H, W) -> (1, 3, H, W).
    return image_tensor.unsqueeze(0)
| |
|
| |
|
def predict_single_image(
    image_path: str, model: nn.Module, device: torch.device, class_names: list[str]
) -> str:
    """Classify one image file and return the name of the predicted class.

    Args:
        image_path: Path to the image to classify.
        model: Classifier to run. It is switched to eval mode as a side effect.
        device: Device that the input batch is moved to before inference.
        class_names: Label names indexed by the model's output class index.

    Returns:
        The entry of ``class_names`` at the index of the highest model output.
    """
    batch = preprocess_image(image_path).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(batch)
        # The index of the maximum along dim 1 is the predicted class id.
        best_index = torch.max(logits, 1).indices

    return class_names[int(best_index.item())]
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: load the checkpoint and classify one sample image.
    checkpoint_path = "./checkpoints/vgg_net_model_50.pth"
    labels = [
        "Bacterialblight",
        "Blast",
        "Brownspot",
        "Healthy",
        "Tungro",
    ]

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = getModel(device=run_device, model_path=checkpoint_path)

    sample_path = "./images/BLAST1_011.jpg"
    # Only the prediction can raise FileNotFoundError (missing image file),
    # so it is the only statement guarded; output happens in the else branch.
    try:
        label = predict_single_image(sample_path, net, run_device, class_names=labels)
    except FileNotFoundError:
        print("Please provide a valid image path to test single image prediction")
    else:
        print("\nSingle image prediction result:")
        print(f"Image: {sample_path}")
        print(f"Predicted label: {label}")
| |
|