Spaces:
Running
Running
| """ | |
| Dataset loading and preprocessing for crop disease detection | |
| """ | |
| import os | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| from pathlib import Path | |
def get_transforms(split='train', input_size=224):
    """
    Build the image preprocessing pipeline for one dataset split.

    Args:
        split: 'train' adds random augmentation; any other value
            (e.g. 'val' or 'test') yields the deterministic pipeline.
        input_size: Side length in pixels of the square image produced
            by the resize step (default: 224).

    Returns:
        transforms.Compose: Resize, optional augmentation, tensor
        conversion and normalization, applied in that order.
    """
    # Every split starts by resizing to a fixed square.
    pipeline = [transforms.Resize((input_size, input_size))]

    if split == 'train':
        # Random flips, rotation and colour jitter are for training only.
        pipeline.extend([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.3),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        ])

    # Shared tail: convert to a tensor, then normalize each channel.
    pipeline.append(transforms.ToTensor())
    pipeline.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(pipeline)
def get_inference_transforms(input_size=224):
    """
    Build the deterministic preprocessing pipeline used for prediction.

    Args:
        input_size: Side length in pixels of the square image produced
            by the resize step (default: 224).

    Returns:
        transforms.Compose: Resize, tensor conversion and normalization,
        with no random augmentation.
    """
    resize_step = transforms.Resize((input_size, input_size))
    tensor_step = transforms.ToTensor()
    normalize_step = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
    return transforms.Compose([resize_step, tensor_step, normalize_step])
class CropDiseaseDataset(Dataset):
    """Custom dataset for crop disease images.

    ``data_dir`` must contain one sub-directory per class; every image file
    directly inside a class directory becomes one ``(path, label)`` sample.
    Directories whose name starts with '.' are ignored.

    Attributes:
        data_dir: Dataset root as a ``Path``.
        transform: Callable applied to each loaded PIL image, or None.
        classes: Class names ordered so that ``classes[label]`` is the name
            of the class with that label.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(image_path_str, label)`` tuples.
    """

    # File suffixes (compared case-insensitively) that count as images.
    _IMAGE_EXTENSIONS = frozenset(
        {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
    )

    def __init__(self, data_dir, transform=None, class_to_idx=None):
        """
        Args:
            data_dir: Path to dataset directory (train/val/test)
            transform: Optional transform to be applied on images
            class_to_idx: Dictionary mapping class names to indices. When
                omitted, labels are assigned from the sorted class
                directory names found in ``data_dir``.

        Raises:
            KeyError: If ``class_to_idx`` is given and ``data_dir`` contains
                a class directory that the mapping does not know.
        """
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.samples = []

        # Class directories actually present on disk, in sorted order.
        found_classes = [
            entry.name
            for entry in sorted(self.data_dir.iterdir())
            if entry.is_dir() and not entry.name.startswith('.')
        ]

        if class_to_idx is None:
            self.classes = found_classes
            self.class_to_idx = {
                cls_name: idx for idx, cls_name in enumerate(found_classes)
            }
        else:
            unknown = [name for name in found_classes if name not in class_to_idx]
            if unknown:
                raise KeyError(
                    f"Class directories {unknown} in {self.data_dir} "
                    f"are not present in class_to_idx"
                )
            self.class_to_idx = class_to_idx
            # Follow the supplied mapping rather than the local directory
            # listing, so labels still index the right name when this split
            # lacks some of the mapping's classes.
            self.classes = sorted(class_to_idx, key=class_to_idx.get)

        # Sort the files so the sample order does not depend on the
        # filesystem's listing order.
        for class_name in found_classes:
            class_idx = self.class_to_idx[class_name]
            for img_path in sorted((self.data_dir / class_name).iterdir()):
                if (img_path.is_file()
                        and img_path.suffix.lower() in self._IMAGE_EXTENSIONS):
                    self.samples.append((str(img_path), class_idx))

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        If the image cannot be opened or decoded, a warning is printed and a
        uniform grey 224x224 RGB placeholder is used instead, so one bad file
        does not abort iteration. The transform, if any, is applied to the
        loaded image or the placeholder.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager closes the underlying file promptly;
            # convert() returns an independent in-memory copy.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: substitute a dummy image, but report
            # the reason instead of discarding it.
            print(f"Warning: Could not load {img_path} ({e}), creating dummy image")
            image = Image.new('RGB', (224, 224), color=(128, 128, 128))

        if self.transform:
            image = self.transform(image)
        return image, label

    def get_class_names(self):
        """Return list of class names"""
        return self.classes

    def get_class_to_idx(self):
        """Return class to index mapping"""
        return self.class_to_idx
def get_data_transforms():
    """Build the transform pair used for training and for evaluation.

    Returns:
        tuple: ``(train_transforms, val_transforms)``. The training
        pipeline resizes to 256x256, takes a random resized crop of 224
        and applies random flips, rotation and colour jitter. The
        validation/test pipeline only resizes to 224x224. Both end with
        tensor conversion and the same normalization step.
    """
    # ImageNet normalization values, shared by both pipelines.
    imagenet_norm = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

    # Random augmentation, applied to training images only.
    augmentation_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
    training_pipeline = transforms.Compose(
        augmentation_steps + [transforms.ToTensor(), imagenet_norm]
    )

    # Deterministic pipeline for validation and test images.
    evaluation_pipeline = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor(), imagenet_norm]
    )

    return training_pipeline, evaluation_pipeline
def create_data_loaders(data_dir, batch_size=32, num_workers=0):
    """Create data loaders for training, validation, and testing.

    Args:
        data_dir: Root directory containing 'train', 'val' and 'test'
            sub-directories.
        batch_size: Batch size used by all three loaders (default: 32).
        num_workers: Worker count passed to each DataLoader (default: 0).

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, class_names)``,
        where ``class_names`` comes from the training dataset. Only the
        training loader shuffles.
    """
    train_tf, eval_tf = get_data_transforms()
    # Pinned host memory is requested only when CUDA is available.
    use_pinned_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle):
        # All three loaders share every setting except shuffling.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=use_pinned_memory
        )

    train_set = CropDiseaseDataset(
        data_dir=os.path.join(data_dir, 'train'),
        transform=train_tf
    )
    # Validation and test reuse the training label mapping so labels agree
    # across splits.
    val_set = CropDiseaseDataset(
        data_dir=os.path.join(data_dir, 'val'),
        transform=eval_tf,
        class_to_idx=train_set.get_class_to_idx()
    )
    test_set = CropDiseaseDataset(
        data_dir=os.path.join(data_dir, 'test'),
        transform=eval_tf,
        class_to_idx=train_set.get_class_to_idx()
    )

    train_loader = _make_loader(train_set, shuffle=True)
    val_loader = _make_loader(val_set, shuffle=False)
    test_loader = _make_loader(test_set, shuffle=False)

    return train_loader, val_loader, test_loader, train_set.get_class_names()
def get_class_weights(data_dir):
    """Calculate inverse-frequency class weights for imbalanced datasets.

    Scans ``<data_dir>/train`` and gives each class the weight
    ``total_samples / (num_classes * class_count)``. A class with no
    samples is treated as having one, which avoids division by zero.

    Args:
        data_dir: Root directory containing the 'train' sub-directory.

    Returns:
        torch.FloatTensor: One weight per class, in the training
        dataset's class order.
    """
    dataset = CropDiseaseDataset(data_dir=os.path.join(data_dir, 'train'))

    n_classes = len(dataset.classes)
    n_samples = len(dataset.samples)

    # Tally samples per label; labels index into dataset.classes.
    per_class_count = [0] * n_classes
    for _path, label in dataset.samples:
        per_class_count[label] += 1

    # Inverse-frequency weight; `count or 1` substitutes 1 for empty classes.
    weights = [
        n_samples / (n_classes * (count or 1))
        for count in per_class_count
    ]
    return torch.FloatTensor(weights)
if __name__ == "__main__":
    # Smoke test: build the three loaders from ./data, print a summary and
    # pull a single training batch to confirm the pipeline runs end to end.
    dataset_root = "data"
    try:
        loaders = create_data_loaders(dataset_root, batch_size=4)
        train_loader, val_loader, test_loader, class_names = loaders

        print(f"Dataset loaded successfully!")
        print(f"Number of classes: {len(class_names)}")
        print(f"Classes: {class_names}")
        print(f"Train batches: {len(train_loader)}")
        print(f"Val batches: {len(val_loader)}")
        print(f"Test batches: {len(test_loader)}")

        # Only the first batch is inspected; an empty loader prints nothing.
        for images, labels in train_loader:
            print(f"Batch shape: {images.shape}")
            print(f"Label shape: {labels.shape}")
            break
    except Exception as e:
        # Any failure is reported with a hint about the expected layout.
        print(f"Error loading dataset: {e}")
        print("Make sure the dataset is properly organized in data/train, data/val, data/test")