File size: 6,264 Bytes
ec5d79d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cff122
 
 
 
 
 
 
 
 
 
ec5d79d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, vit_b_32
from transformers import LevitForImageClassification, logging

# `logging` here is `transformers.logging` (imported above), not the stdlib
# module. Raise its threshold to ERROR so the library's info/warning messages
# are suppressed (presumably e.g. the notice printed when LeViT is loaded with
# ignore_mismatched_sizes=True below).
logging.set_verbosity_error()

class PlantDiseaseClassifier:
    """Load a trained plant-disease classifier and run it on images.

    The backbone is chosen by ``model_type`` (``"resnet"``, ``"levit"`` or
    ``"vit"``) and its weights come from the local checkpoint at
    ``model_path``. The model is moved to CUDA when available (CPU otherwise)
    and switched to eval mode once, at construction time.
    """

    # Label order of the trained classification heads: logit index i maps to
    # class_names[i]. Evaluation sub-directories must be named exactly like
    # these entries to be picked up by calculate_accuracy().
    class_names = [
        'Apple___Apple_scab',
        'Apple___Black_rot',
        'Apple___Cedar_apple_rust',
        'Apple___healthy',
        'Blueberry___healthy',
        'Cherry_(including_sour)___Powdery_mildew',
        'Cherry_(including_sour)___healthy',
        'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
        'Corn_(maize)___Common_rust_',
        'Corn_(maize)___Northern_Leaf_Blight',
        'Corn_(maize)___healthy',
        'Grape___Black_rot',
        'Grape___Esca_(Black_Measles)',
        'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
        'Grape___healthy',
        'Orange___Haunglongbing_(Citrus_greening)',
        'Peach___Bacterial_spot',
        'Peach___healthy',
        'Pepper,_bell___Bacterial_spot',
        'Pepper,_bell___healthy',
        'Potato___Early_blight',
        'Potato___Late_blight',
        'Potato___healthy',
        'Raspberry___healthy',
        'Soybean___healthy',
        'Squash___Powdery_mildew',
        'Strawberry___Leaf_scorch',
        'Strawberry___healthy',
        'Tomato___Bacterial_spot',
        'Tomato___Early_blight',
        'Tomato___Late_blight',
        'Tomato___Leaf_Mold',
        'Tomato___Septoria_leaf_spot',
        'Tomato___Spider_mites Two-spotted_spider_mite',
        'Tomato___Target_Spot',
        'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
        'Tomato___Tomato_mosaic_virus',
        'Tomato___healthy',
    ]

    # Lower-case file-name suffixes treated as images when scanning an
    # evaluation directory (matched against the lower-cased file name, so
    # ".JPG" / ".JPEG" files are included too).
    image_extensions = ('.jpg', '.jpeg', '.png')

    def __init__(self, model_type, model_path, batch_size=32):
        """Build the model and the preprocessing pipeline.

        Args:
            model_type: One of ``"resnet"``, ``"levit"``, ``"vit"``.
            model_path: Path to a saved ``state_dict`` checkpoint.
            batch_size: Batch size used by ``calculate_accuracy``.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.model_type = model_type

        self.model = self._load_model(model_type, model_path).to(self.device)
        # Inference only: eval() disables dropout and makes batch-norm use its
        # running statistics.
        self.model.eval()

        # Resize to 224x224, convert to a float tensor, then normalize with the
        # ImageNet per-channel mean/std values below.
        self.data_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def _load_model(self, model_type, model_path):
        """Construct the backbone for ``model_type`` and load its checkpoint.

        Every head is sized to ``len(self.class_names)`` outputs; checkpoint
        tensors are mapped onto ``self.device`` while loading.

        Raises:
            ValueError: If ``model_type`` is not supported.
        """
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source. On torch >= 1.13 consider passing
        # weights_only=True; it is omitted here to stay compatible with older
        # torch releases.
        num_classes = len(self.class_names)
        if model_type == "resnet":
            # `pretrained=` is deprecated in newer torchvision in favour of
            # `weights=`; kept for compatibility with older releases.
            model = resnet50(pretrained=False)
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        elif model_type == "levit":
            # Start from the LeViT-128S architecture on the Hugging Face Hub;
            # heads whose size differs are re-initialised to num_classes outputs
            # and then overwritten by the (strict) checkpoint load below.
            model = LevitForImageClassification.from_pretrained(
                "facebook/levit-128S", num_labels=num_classes, ignore_mismatched_sizes=True,
            )
            state_dict = torch.load(model_path, map_location=self.device)
            # Drop distillation-head weights: the strict load would reject them
            # as unexpected keys (presumably the checkpoint was saved from a
            # LeViT-with-teacher model).
            filtered_state_dict = {
                k: v for k, v in state_dict.items() if not k.startswith("classifier_distill")
            }
            model.load_state_dict(filtered_state_dict)
        elif model_type == "vit":
            model = vit_b_32(pretrained=False, num_classes=num_classes)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        return model

    def _logits(self, outputs):
        """Return the raw logits tensor from a forward-pass result.

        The Hugging Face LeViT model returns an output object exposing
        ``.logits``; the torchvision models return the tensor itself.
        """
        return outputs.logits if self.model_type == "levit" else outputs

    class _PlantDiseaseDataset(Dataset):
        """Image/label pairs read from a ``<root>/<class_name>/<image>`` tree.

        Only sub-directories whose name is in
        ``PlantDiseaseClassifier.class_names`` are scanned, and within them only
        files whose lower-cased name ends with one of
        ``PlantDiseaseClassifier.image_extensions``. A sample's label is the
        index of its directory name in ``class_names``.
        """

        def __init__(self, directory_path, transform=None):
            self.directory_path = directory_path
            self.transform = transform

            # Built once so each file costs an O(1) lookup instead of a
            # list.index scan over all class names.
            label_of = {name: i for i, name in enumerate(PlantDiseaseClassifier.class_names)}
            extensions = PlantDiseaseClassifier.image_extensions

            self.image_files = []
            self.labels = []
            # os.listdir order is arbitrary; sorting keeps sample order
            # reproducible across runs and platforms.
            for class_name in sorted(os.listdir(directory_path)):
                class_dir = os.path.join(directory_path, class_name)
                if class_name not in label_of or not os.path.isdir(class_dir):
                    continue
                label = label_of[class_name]
                for img_file in sorted(os.listdir(class_dir)):
                    if img_file.lower().endswith(extensions):
                        self.image_files.append(os.path.join(class_dir, img_file))
                        self.labels.append(label)

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, idx):
            """Return ``(image, label)``; ``image`` is transformed if a transform was given."""
            img_path = self.image_files[idx]
            # convert() returns a new, fully loaded image, so the file handle
            # can be closed as soon as the `with` block exits.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            label = self.labels[idx]
            if self.transform:
                image = self.transform(image)
            return image, label

    def calculate_accuracy(self, test_dir):
        """Return top-1 accuracy, in percent, over the images under ``test_dir``.

        ``test_dir`` should contain one sub-directory per class, named as in
        ``class_names``. Returns ``0.0`` when no matching images are found.
        """
        dataset = self._PlantDiseaseDataset(test_dir, transform=self.data_transforms)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                logits = self._logits(self.model(inputs))
                _, preds = torch.max(logits, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        return (correct / total) * 100 if total > 0 else 0.0

    def predict(self, image):
        """Return the predicted class name for a single image.

        Args:
            image: A PIL image or, as a backward-compatible extension, a path
                (``str`` / ``os.PathLike``) to an image file that is opened here.

        Returns:
            The entry of ``class_names`` with the highest logit.
        """
        if isinstance(image, (str, os.PathLike)):
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif image.mode != "RGB":
            # The normalization above expects exactly three channels.
            image = image.convert("RGB")

        # Preprocess, add a batch dimension of 1 and move to the model's device.
        batch = self.data_transforms(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self._logits(self.model(batch))
            _, predicted_idx = torch.max(logits, 1)

        return self.class_names[predicted_idx.item()]

def predict_image_with_all_models(image_path, classifiers):
    """Print the true label and every classifier's prediction for one image.

    The true label is taken to be the name of the image's parent directory,
    i.e. images are assumed to live at ``<root>/<class_name>/<image>``.

    Args:
        image_path: Path to an image file.
        classifiers: Mapping of display name -> classifier exposing
            ``predict(image)`` that takes a PIL image (as
            ``PlantDiseaseClassifier.predict`` does).
    """
    actual_disease = os.path.basename(os.path.dirname(image_path))
    print(f"Actual disease: {actual_disease}\n")

    # PlantDiseaseClassifier has no `predict_image` method; its `predict`
    # takes a PIL image, so decode the file once here and share it across all
    # models. convert() yields a loaded copy, so the file can be closed at once.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    for model_name, classifier in classifiers.items():
        predicted_class = classifier.predict(image)
        print(f"Model: {model_name}, Predicted Class: {predicted_class}")