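Training script for a plant-disease image classifier: a `timm` Vision Transformer (`vit_base_patch16_224`) fine-tuned over 38 disease classes, with PyTorch data loading, Weights & Biases logging, and export of the best checkpoint plus a class-index mapping.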

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from pathlib import Path
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import wandb
```

```python
class PlantDiseaseDataset(Dataset):
    """Loads plant-leaf images and their integer class labels."""

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label
```
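As a quick sanity check, the dataset class can be exercised on its own before wiring up the full pipeline. A minimal sketch, assuming two hypothetical image files exist on disk:

```python
from torchvision import transforms

# Hypothetical paths and labels, purely for illustration
paths = ["samples/leaf_healthy.jpg", "samples/leaf_rust.jpg"]
labels = [0, 1]

to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

ds = PlantDiseaseDataset(paths, labels, transform=to_tensor)
image, label = ds[0]
print(image.shape, label)  # torch.Size([3, 224, 224]) 0
```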

```python
class PlantDiseaseClassifier:
    def __init__(self, data_dir, model_name='vit_base_patch16_224', num_classes=38):
        self.data_dir = Path(data_dir)
        self.model_name = model_name
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize wandb
        wandb.init(project="plant-disease-classification")

    def prepare_data(self):
        """Prepare dataset and create data loaders"""
        # Data augmentation and normalization for training
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Just normalization for validation/testing
        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Collect all image paths and labels; each class lives in its own
        # sub-directory, sorted so the class-to-index mapping is stable
        image_paths = []
        labels = []
        self.class_to_idx = {}

        class_dirs = sorted(d for d in self.data_dir.glob('*') if d.is_dir())
        for idx, class_dir in enumerate(class_dirs):
            self.class_to_idx[class_dir.name] = idx
            for img_path in class_dir.glob('*.jpg'):
                image_paths.append(str(img_path))
                labels.append(idx)

        # Split data
        train_paths, val_paths, train_labels, val_labels = train_test_split(
            image_paths, labels, test_size=0.2, stratify=labels, random_state=42
        )

        # Create datasets
        train_dataset = PlantDiseaseDataset(train_paths, train_labels, train_transform)
        val_dataset = PlantDiseaseDataset(val_paths, val_labels, val_transform)

        # Create data loaders
        self.train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
        self.val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

        return self.train_loader, self.val_loader

    def create_model(self):
        """Initialize the Vision Transformer model"""
        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=self.num_classes
        )
        self.model = self.model.to(self.device)

        # Loss function and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=2e-5,
            weight_decay=0.01
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=10  # matches the default 10 training epochs
        )

        return self.model

    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0

        progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')

        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar.set_postfix({
                'Loss': total_loss / (batch_idx + 1),
                'Acc': 100. * correct / total
            })

            # Log to wandb
            wandb.log({
                'train_loss': loss.item(),
                'train_acc': 100. * correct / total
            })

        return total_loss / len(self.train_loader), 100. * correct / total

    def validate(self):
        """Validate the model"""
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in tqdm(self.val_loader, desc='Validating'):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        accuracy = 100. * correct / total
        avg_loss = total_loss / len(self.val_loader)

        # Log to wandb
        wandb.log({
            'val_loss': avg_loss,
            'val_acc': accuracy
        })

        return avg_loss, accuracy

    def train(self, epochs=10):
        """Complete training process"""
        best_acc = 0

        for epoch in range(epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate()
            self.scheduler.step()

            print(f'\nEpoch {epoch}:')
            print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

            # Save best model
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'class_to_idx': self.class_to_idx
                }, 'best_model.pth')

        # Close the wandb run once all epochs are done
        wandb.finish()

    def save_for_huggingface(self):
        """Save the best weights and class mapping for upload to the Hub"""
        # Load best model
        checkpoint = torch.load('best_model.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # timm models are plain nn.Modules and do not provide save_pretrained(),
        # so persist the fine-tuned weights as a state dict
        torch.save(self.model.state_dict(), 'plant_disease_model.pth')

        # Save class mapping
        idx_to_class = {v: k for k, v in self.class_to_idx.items()}
        pd.Series(idx_to_class).to_json('class_mapping.json')
```
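To actually publish the artifacts produced by `save_for_huggingface()`, one option is the `huggingface_hub` client. A sketch, assuming a hypothetical repo id and that you have already authenticated (e.g. via `huggingface-cli login`):

```python
from huggingface_hub import HfApi, create_repo

repo_id = "your-username/plant-disease-vit"  # hypothetical repo id
create_repo(repo_id, exist_ok=True)

api = HfApi()
for filename in ["plant_disease_model.pth", "class_mapping.json"]:
    api.upload_file(
        path_or_fileobj=filename,
        path_in_repo=filename,
        repo_id=repo_id,
    )
```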

```python
if __name__ == "__main__":
    classifier = PlantDiseaseClassifier(data_dir="path/to/dataset")
    classifier.prepare_data()
    classifier.create_model()
    classifier.train(epochs=10)
    classifier.save_for_huggingface()
```
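For inference, the saved checkpoint can be reloaded into the same timm architecture. A minimal sketch, assuming the `best_model.pth` produced by the training run above and a hypothetical test image `leaf.jpg`:

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the architecture and load the fine-tuned weights
checkpoint = torch.load('best_model.pth', map_location=device)
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=38)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device).eval()

# Same preprocessing as the validation transform above
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Invert the class mapping stored in the checkpoint
idx_to_class = {v: k for k, v in checkpoint['class_to_idx'].items()}

image = Image.open('leaf.jpg').convert('RGB')  # hypothetical test image
batch = preprocess(image).unsqueeze(0).to(device)

with torch.no_grad():
    probs = model(batch).softmax(dim=-1)
pred = probs.argmax(dim=-1).item()
print(idx_to_class[pred], probs[0, pred].item())
```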
