iqramukhtiar committed · verified
Commit 78446ba · 1 Parent(s): faaf848

Upload 3 files

download_huggingface_dataset.py ADDED
import os
import json

from tqdm import tqdm
from datasets import load_dataset

def download_plantvillage_from_huggingface():
    """
    Downloads the PlantVillage dataset from Hugging Face and organizes it for training.
    """
    print("Downloading PlantVillage dataset from Hugging Face...")

    # Create the root directory for the dataset
    os.makedirs('PlantVillage', exist_ok=True)

    try:
        # Load the dataset from Hugging Face
        dataset = load_dataset("GVJahnavi/PlantVillage_dataset")
        print(f"Dataset loaded successfully with {len(dataset['train'])} training samples")

        # Get the class names from the label feature
        labels = dataset['train'].features['label'].names
        print(f"Found {len(labels)} classes: {labels}")

        # Create a directory for each class and save its images
        for label_idx, label_name in enumerate(labels):
            label_dir = os.path.join('PlantVillage', label_name)
            os.makedirs(label_dir, exist_ok=True)

            # Get the samples belonging to this class
            class_samples = dataset['train'].filter(lambda example: example['label'] == label_idx)
            print(f"Processing class {label_name} with {len(class_samples)} samples")

            # Save the images for this class; convert to RGB so that
            # palette/RGBA images can be written as JPEG
            for i, sample in enumerate(tqdm(class_samples, desc=f"Saving {label_name}")):
                img = sample['image'].convert('RGB')
                img_path = os.path.join(label_dir, f"{label_name}_{i}.jpg")
                img.save(img_path)

        # Save the class names to a file for later use
        with open('class_names.json', 'w') as f:
            json.dump(labels, f)

        print("Dataset downloaded and organized successfully")
        return True

    except Exception as e:
        print(f"Error downloading dataset from Hugging Face: {e}")
        return False

if __name__ == "__main__":
    download_plantvillage_from_huggingface()
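
As a quick sanity check (not part of the uploaded files), a minimal sketch along these lines can confirm that the on-disk layout matches the saved class list; it assumes the script above was run in the current working directory:

import os
import json

# Compare the folders under PlantVillage/ with the saved class list
with open('class_names.json') as f:
    class_names = json.load(f)

for name in class_names:
    class_dir = os.path.join('PlantVillage', name)
    n_images = len(os.listdir(class_dir)) if os.path.isdir(class_dir) else 0
    print(f"{name}: {n_images} images")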
run_huggingface_training.py ADDED
import os
import sys

def main():
    """
    Main script to run the training process using the Hugging Face dataset.
    """
    print("=== Plant Disease Model Training with Hugging Face Dataset ===")

    # First, check that the datasets library is installed
    try:
        import datasets  # imported only to verify the installation
        print("Hugging Face datasets library is installed.")
    except ImportError:
        print("Error: Hugging Face datasets library is not installed.")
        print("Installing required packages...")
        # Install into the interpreter that is running this script
        os.system(f"{sys.executable} -m pip install datasets")
        print("Please run this script again after installation.")
        sys.exit(1)

    # Step 1: Download the PlantVillage dataset from Hugging Face
    print("\nStep 1: Downloading dataset from Hugging Face")
    try:
        import download_huggingface_dataset
        success = download_huggingface_dataset.download_plantvillage_from_huggingface()
        if success:
            print("Dataset downloaded successfully from Hugging Face")
        else:
            print("Failed to download dataset from Hugging Face")
            sys.exit(1)
    except Exception as e:
        print(f"Error during dataset download: {e}")
        sys.exit(1)

    # Step 2: Train the model
    print("\nStep 2: Training model with Hugging Face dataset")
    try:
        import train_huggingface_model
        train_huggingface_model.train_model_with_huggingface_data()
    except Exception as e:
        print(f"Error during model training: {e}")
        sys.exit(1)

    print("\nTraining completed successfully!")
    print("You can now run the application with 'python app.py'")

if __name__ == "__main__":
    main()
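
For reference, a successful end-to-end run of run_huggingface_training.py prints roughly the following, assembled from the scripts' own messages (the elided portions depend on dataset size and training progress):

=== Plant Disease Model Training with Hugging Face Dataset ===
Hugging Face datasets library is installed.

Step 1: Downloading dataset from Hugging Face
Downloading PlantVillage dataset from Hugging Face...
...
Dataset downloaded successfully from Hugging Face

Step 2: Training model with Hugging Face dataset
...
Training completed successfully!
You can now run the application with 'python app.py'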
train_huggingface_model.py ADDED
import os
import json
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
from tqdm import tqdm

def train_model_with_huggingface_data():
    """
    Trains a model using the PlantVillage dataset downloaded from Hugging Face.
    """
    print("Starting model training with Hugging Face dataset...")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data transformations: augmentation for training, deterministic resizing for validation
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # Load the dataset
    print("Loading dataset...")
    try:
        dataset_path = 'PlantVillage'
        if not os.path.exists(dataset_path):
            print(f"Error: Dataset directory {dataset_path} not found.")
            print("Please run download_huggingface_dataset.py first.")
            return

        # Build one ImageFolder per split so each split keeps its own transform.
        # (The two subsets returned by random_split share a single underlying
        # dataset object, so assigning .dataset.transform on one split would
        # silently change the other as well.)
        train_base = datasets.ImageFolder(dataset_path, transform=data_transforms['train'])
        val_base = datasets.ImageFolder(dataset_path, transform=data_transforms['val'])

        # Split the indices once with a fixed seed and reuse them for both views
        train_size = int(0.8 * len(train_base))
        generator = torch.Generator().manual_seed(42)
        indices = torch.randperm(len(train_base), generator=generator).tolist()
        train_dataset = Subset(train_base, indices[:train_size])
        val_dataset = Subset(val_base, indices[train_size:])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

        # Save class names
        class_names = train_base.classes
        with open('class_names.json', 'w') as f:
            json.dump(class_names, f)

        print(f"Dataset loaded with {len(class_names)} classes")
        print(f"Training set: {len(train_dataset)} images")
        print(f"Validation set: {len(val_dataset)} images")

        # Load a pre-trained model
        print("Loading pre-trained model...")
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Replace the final layer to match our number of classes
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, len(class_names))

        model = model.to(device)

        # Define loss function, optimizer, and learning-rate schedule
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        # Train the model
        num_epochs = 10
        best_acc = 0.0

        print(f"Starting training for {num_epochs} epochs...")
        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

            # Training phase
            model.train()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over the training data
            for inputs, labels in tqdm(train_loader, desc="Training"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward pass + optimizer step
                loss.backward()
                optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            scheduler.step()

            epoch_loss = running_loss / len(train_dataset)
            epoch_acc = running_corrects.double() / len(train_dataset)

            print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Validation phase
            model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over the validation data
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass only; no gradients needed for validation
                with torch.no_grad():
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(val_dataset)
            epoch_acc = running_corrects.double() / len(val_dataset)

            print(f'Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Save the best model
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'plant_disease_model.pth')
                print(f"Saved new best model with accuracy: {best_acc:.4f}")

            print()

        print(f'Best val Acc: {best_acc:.4f}')
        print('Model saved as plant_disease_model.pth')

    except Exception as e:
        print(f"Error during training: {e}")

if __name__ == "__main__":
    start_time = time.time()
    train_model_with_huggingface_data()
    end_time = time.time()
    print(f"Training completed in {(end_time - start_time)/60:.2f} minutes")