import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50, vit_b_32
from transformers import LevitForImageClassification, logging
logging.set_verbosity_error()
class PlantDiseaseClassifier:
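    """Evaluate trained plant-disease classifiers (ResNet-50, ViT-B/32 or LeViT-128S).

    Provides test-set accuracy over a directory of class-named folders and
    single-image prediction across the 38 PlantVillage-style disease classes.
    """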
    # Class labels; the order must match the label indices the classifier heads were trained with.
    class_names = ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy']

    def __init__(self, model_type, model_path, batch_size=32):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = batch_size
        self.model_type = model_type
        # Initialize and load the appropriate model
        self.model = self._load_model(model_type, model_path)
        self.model = self.model.to(self.device)
        self.model.eval()
        # Data transformation pipeline
        self.data_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def _load_model(self, model_type, model_path):
        # Each checkpoint is expected to be a plain state_dict (e.g. saved via torch.save(model.state_dict(), path)).
        if model_type == "resnet":
            model = resnet50(pretrained=False)
            model.fc = torch.nn.Linear(model.fc.in_features, len(self.class_names))
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        elif model_type == "levit":
            model = LevitForImageClassification.from_pretrained(
                "facebook/levit-128S", num_labels=len(self.class_names), ignore_mismatched_sizes=True,
            )
            state_dict = torch.load(model_path, map_location=self.device)
            # Drop distillation-head weights (present if the checkpoint came from the teacher variant);
            # LevitForImageClassification has no classifier_distill parameters.
            filtered_state_dict = {k: v for k, v in state_dict.items() if not k.startswith("classifier_distill")}
            model.load_state_dict(filtered_state_dict)
        elif model_type == "vit":
            model = vit_b_32(pretrained=False, num_classes=len(self.class_names))
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        return model

    class _PlantDiseaseDataset(Dataset):
        def __init__(self, directory_path, transform=None):
            self.directory_path = directory_path
            self.transform = transform
            # Collect all images and their respective class labels
            self.image_files = []
            self.labels = []
            for class_name in os.listdir(directory_path):
                class_dir = os.path.join(directory_path, class_name)
                if os.path.isdir(class_dir) and class_name in PlantDiseaseClassifier.class_names:
                    for img_file in os.listdir(class_dir):
                        if img_file.lower().endswith(('.jpg', '.png')):
                            self.image_files.append(os.path.join(class_dir, img_file))
                            self.labels.append(PlantDiseaseClassifier.class_names.index(class_name))

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, idx):
            img_path = self.image_files[idx]
            image = Image.open(img_path).convert('RGB')
            label = self.labels[idx]
            if self.transform:
                image = self.transform(image)
            return image, label

    def calculate_accuracy(self, test_dir):
        dataset = self._PlantDiseaseDataset(test_dir, transform=self.data_transforms)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.model(inputs)
                # Hugging Face LeViT returns an output object with .logits; torchvision models return raw logits
                logits = outputs.logits if self.model_type in ["levit"] else outputs
                _, preds = torch.max(logits, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        accuracy = (correct / total) * 100 if total > 0 else 0.0
        return accuracy

    def predict(self, image):
        # Ensure the image is in RGB format if not already
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Transform the image to match the model's input requirements
        transformed_image = self.data_transforms(image).unsqueeze(0)
        transformed_image = transformed_image.to(self.device)
        # Make prediction
        with torch.no_grad():
            outputs = self.model(transformed_image)
            logits = outputs.logits if self.model_type in ["levit"] else outputs
            _, predicted_idx = torch.max(logits, 1)
        predicted_class = self.class_names[predicted_idx.item()]
        return predicted_class

def predict_image_with_all_models(image_path, classifiers):
    actual_disease = os.path.basename(os.path.dirname(image_path))
    print(f"Actual disease: {actual_disease}\n")
    # PlantDiseaseClassifier.predict expects a PIL image, so open the file once and reuse it.
    image = Image.open(image_path)
    for model_name, classifier in classifiers.items():
        predicted_class = classifier.predict(image)
        print(f"Model: {model_name}, Predicted Class: {predicted_class}")