import json
from pathlib import Path

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import wandb
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import ViTFeatureExtractor, ViTForImageClassification
class PlantDiseaseDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from image file paths.

    Images are opened lazily in ``__getitem__``, converted to RGB, and passed
    through ``transform`` when one is given.

    Args:
        image_paths: Sequence of image file paths (str or path-like).
        labels: Sequence of labels aligned index-for-index with ``image_paths``.
        transform: Optional callable applied to the PIL image.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        image_path = self.image_paths[idx]
        # Close the file handle promptly: convert() loads the pixel data and
        # returns a new image, so the opened file is no longer needed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
class PlantDiseaseClassifier:
    """Fine-tune a timm Vision Transformer on a folder-per-class image dataset.

    Expected layout: ``data_dir/<class_name>/<image file>``. Each sub-directory
    becomes one class; class indices follow the sorted directory names.
    Metrics are logged to Weights & Biases.
    """

    # Lower-cased file suffixes treated as images when scanning class folders.
    IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir, model_name='vit_base_patch16_224', num_classes=38):
        """Store configuration, pick a device and start a wandb run.

        Args:
            data_dir: Root directory containing one sub-directory per class.
            model_name: timm model identifier passed to ``timm.create_model``.
            num_classes: Size of the classification head. Must be at least
                the number of class folders found by ``prepare_data``.
        """
        self.data_dir = Path(data_dir)
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Filled in by prepare_data(); stored in checkpoints and exported mappings.
        self.class_to_idx = {}
        # Initialize wandb
        wandb.init(project="plant-disease-classification")

    @staticmethod
    def _build_transforms():
        """Return ``(train_transform, val_transform)`` torchvision pipelines."""
        # ImageNet mean/std, matching the pretrained backbone's statistics.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # Data augmentation and normalization for training
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        # Deterministic resize + center crop for validation/testing
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        return train_transform, val_transform

    def _collect_samples(self):
        """Scan ``data_dir`` for class folders and the images inside them.

        Sets ``self.class_to_idx`` (sorted folder name -> contiguous index).

        Returns:
            ``(image_paths, labels)``: parallel lists of path strings and ints.

        Raises:
            ValueError: If more class folders exist than ``num_classes``, or
                no image files are found.
        """
        # Filter to directories *before* enumerating so indices are 0..K-1
        # even when stray files sit next to the class folders.
        class_dirs = sorted(p for p in self.data_dir.iterdir() if p.is_dir())
        self.class_to_idx = {d.name: idx for idx, d in enumerate(class_dirs)}
        if len(self.class_to_idx) > self.num_classes:
            raise ValueError(
                f"Found {len(self.class_to_idx)} class folders in {self.data_dir} "
                f"but num_classes={self.num_classes}"
            )

        image_paths = []
        labels = []
        for class_dir in class_dirs:
            idx = self.class_to_idx[class_dir.name]
            for img_path in sorted(class_dir.iterdir()):
                # Case-insensitive so e.g. '.JPG' files are not silently skipped.
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                    image_paths.append(str(img_path))
                    labels.append(idx)

        if not image_paths:
            raise ValueError(f"No image files found under {self.data_dir}")
        return image_paths, labels

    def prepare_data(self, *, batch_size=32, num_workers=4):
        """Build train/val datasets with an 80/20 stratified split and their loaders.

        Args:
            batch_size: Batch size for both loaders.
            num_workers: DataLoader worker processes for both loaders.

        Returns:
            ``(train_loader, val_loader)``; also stored on ``self``.
        """
        train_transform, val_transform = self._build_transforms()
        image_paths, labels = self._collect_samples()

        # Split data
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            image_paths, labels, test_size=0.2, stratify=labels, random_state=42
        )

        # Create datasets
        train_dataset = PlantDiseaseDataset(train_paths, train_labels, train_transform)
        val_dataset = PlantDiseaseDataset(val_paths, val_labels, val_transform)

        # Create data loaders
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )
        return self.train_loader, self.val_loader

    def create_model(self):
        """Create the pretrained timm model, loss, AdamW optimizer and cosine scheduler.

        Returns:
            The model, moved to ``self.device``.
        """
        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=self.num_classes,
        )
        self.model = self.model.to(self.device)

        # Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=2e-5,
            weight_decay=0.01,
        )
        # T_max is re-synced to the real epoch count in train().
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=10,
        )
        return self.model

    def train_epoch(self, epoch):
        """Train for one epoch.

        Logs the per-batch loss and running accuracy to wandb.

        Returns:
            ``(mean_batch_loss, accuracy_percent)`` for the epoch.
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar.set_postfix({
                'Loss': total_loss / (batch_idx + 1),
                'Acc': 100. * correct / total,
            })

            # Log to wandb
            wandb.log({
                'train_loss': loss.item(),
                'train_acc': 100. * correct / total,
            })

        # Guard against an empty loader instead of raising ZeroDivisionError.
        num_batches = max(len(self.train_loader), 1)
        return total_loss / num_batches, 100. * correct / max(total, 1)

    def validate(self):
        """Evaluate on the validation loader without gradients.

        Logs the epoch's validation loss/accuracy to wandb.

        Returns:
            ``(mean_batch_loss, accuracy_percent)``.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in tqdm(self.val_loader, desc='Validating'):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        accuracy = 100. * correct / max(total, 1)
        avg_loss = total_loss / max(len(self.val_loader), 1)

        # Log to wandb
        wandb.log({
            'val_loss': avg_loss,
            'val_acc': accuracy,
        })
        return avg_loss, accuracy

    def train(self, epochs=10):
        """Run ``epochs`` train/validate cycles and keep the best checkpoint.

        The checkpoint with the highest validation accuracy is written to
        ``best_model.pth``. The wandb run is finished at the end.
        """
        # The scheduler is created with T_max=10; match it to the actual
        # number of epochs so the cosine curve ends where training ends.
        if getattr(self.scheduler, 'T_max', epochs) != epochs:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=epochs
            )

        # -inf (not 0) so the first epoch always saves a checkpoint even at 0%
        # accuracy; otherwise save_for_huggingface() would find no file.
        best_acc = float('-inf')
        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            self.scheduler.step()

            print(f'\nEpoch {epoch}:')
            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

            # Save best model
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'class_to_idx': self.class_to_idx,
                    'model_name': self.model_name,
                    'num_classes': self.num_classes,
                    'epoch': epoch,
                    'val_acc': val_acc,
                }, 'best_model.pth')

        wandb.finish()

    def save_for_huggingface(self, output_dir='plant_disease_model'):
        """Reload the best checkpoint and export weights plus the class mapping.

        Writes the model to ``output_dir`` and ``class_mapping.json``
        (``{"<index>": "<class name>"}``) to the working directory.
        """
        # Load best model (map_location lets a GPU checkpoint load on CPU).
        checkpoint = torch.load('best_model.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        class_to_idx = checkpoint.get('class_to_idx') or self.class_to_idx
        idx_to_class = {v: k for k, v in class_to_idx.items()}

        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        if hasattr(self.model, 'save_pretrained'):
            self.model.save_pretrained(out_dir)
        else:
            # timm models are plain nn.Modules with no save_pretrained().
            # Write weights + a config.json in a timm-hub-like layout.
            # NOTE(review): confirm key names against the timm version in use.
            torch.save(self.model.state_dict(), out_dir / 'pytorch_model.bin')
            config = {
                'architecture': self.model_name,
                'num_classes': self.num_classes,
                'label_names': [idx_to_class[i] for i in sorted(idx_to_class)],
            }
            with open(out_dir / 'config.json', 'w', encoding='utf-8') as f:
                json.dump(config, f, indent=2)

        # Save class mapping
        with open('class_mapping.json', 'w', encoding='utf-8') as f:
            json.dump({str(i): idx_to_class[i] for i in sorted(idx_to_class)}, f, indent=2)
# Script entry point: runs the full pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    classifier = PlantDiseaseClassifier(data_dir="path/to/dataset")
    classifier.prepare_data()
    classifier.create_model()
    classifier.train(epochs=10)
    classifier.save_for_huggingface()