deepshah23 committed
Commit e1ab6a5 · verified · 1 Parent(s): adf39c3

Upload train_digit_classifier.py

Files changed (1)
  1. train_digit_classifier.py +300 -0
train_digit_classifier.py ADDED
"""
train_digit_classifier.py

A fully documented training script for a convolutional neural network (CNN)
classifier trained on MNIST + EMNIST digits + blank images.

Author: Deep Shah
License: GPL-3.0
"""

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
from sklearn.model_selection import train_test_split
import os
import copy  # used to snapshot the best model weights during early stopping

# ----------------------------------------------------------------------
# 1. Reproducibility Setup
# ----------------------------------------------------------------------

# Set fixed seeds to make results deterministic (important for debugging and reproducibility)
torch.manual_seed(42)
np.random.seed(42)

# ----------------------------------------------------------------------
# 2. Device Selection
# ----------------------------------------------------------------------

# Automatically use the GPU if available; fall back to CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# ----------------------------------------------------------------------
# 3. EMNIST Loader (Custom Dataset class)
# ----------------------------------------------------------------------

class EMNISTDigitsDataset(Dataset):
    """
    A PyTorch-compatible wrapper for the EMNIST digits dataset loaded via TensorFlow Datasets.
    Ensures data is shaped correctly and optionally transformed.
    """

    def __init__(self, split="train", transform=None):
        import tensorflow_datasets as tfds
        ds = tfds.load("emnist/digits", split=split, as_supervised=True)
        self.images = []
        self.labels = []
        for image, label in tfds.as_numpy(ds):
            if image.ndim == 2:
                image = image[..., np.newaxis]
            elif image.ndim == 4 and image.shape[0] == 1:
                image = image[0]
            self.images.append(image)
            self.labels.append(label)
        self.images = np.array(self.images, dtype=np.float32) / 255.0  # Normalize to [0,1]
        self.labels = np.array(self.labels, dtype=np.int64)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert HWC numpy array -> CHW float tensor, then apply the optional transform.
        image = torch.tensor(self.images[idx].transpose(2, 0, 1), dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        if self.transform:
            image = self.transform(image)
        return image, label

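# Usage sketch (an addition; DEMO_EMNIST_SAMPLE is a hypothetical opt-in flag): loading
# EMNIST triggers a TFDS download, so this quick inspection is guarded rather than run
# unconditionally.
if os.environ.get("DEMO_EMNIST_SAMPLE"):
    _demo_ds = EMNISTDigitsDataset(split="test")
    _demo_img, _demo_label = _demo_ds[0]
    print(f"[DEMO] EMNIST sample shape={tuple(_demo_img.shape)}, label={int(_demo_label)}")
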
# ----------------------------------------------------------------------
# 4. Data Augmentation Strategy
# ----------------------------------------------------------------------

# We use a modest augmentation strategy to improve generalization
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(10),  # Handle slanted handwriting
    transforms.RandomAffine(degrees=0, scale=(0.9, 1.1), translate=(0.1, 0.1)),  # Simulate slight distortions
    transforms.ToTensor()
])

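# Sanity-check sketch (not part of the original pipeline): the Compose above should map a
# 1x28x28 float tensor back to a tensor of the same shape and dtype, which is what the
# per-sample augmentation inside the training loop below relies on.
_aug_probe = train_transform(torch.zeros(1, 28, 28))
assert _aug_probe.shape == (1, 28, 28) and _aug_probe.dtype == torch.float32
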
# ----------------------------------------------------------------------
# 5. Load Datasets (MNIST + EMNIST + Blank)
# ----------------------------------------------------------------------

# Load MNIST
mnist_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True)
mnist_images = mnist_dataset.data.numpy().astype(np.float32) / 255.0
mnist_images = mnist_images[..., np.newaxis]
mnist_labels = mnist_dataset.targets.numpy()

# Load EMNIST
emnist_dataset = EMNISTDigitsDataset(split="train", transform=None)
emnist_images = emnist_dataset.images
emnist_labels = emnist_dataset.labels

# Create blank (all-black) 28x28 images, labeled with class 10
x_blank = np.zeros((5000, 28, 28, 1), dtype=np.float32)
y_blank = np.full((5000,), 10, dtype=np.int64)

# Combine all datasets
x_combined = np.concatenate([mnist_images, emnist_images, x_blank], axis=0)
y_combined = np.concatenate([mnist_labels, emnist_labels, y_blank], axis=0)

# Shuffle for randomness
indices = np.random.permutation(len(x_combined))
x_combined = x_combined[indices]
y_combined = y_combined[indices]

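# Quick distribution check (an added sketch, not required for training): print how many
# samples each class contributes, to confirm the blank class (10) is represented.
_classes, _counts = np.unique(y_combined, return_counts=True)
print("[INFO] Class counts:", dict(zip(_classes.tolist(), _counts.tolist())))
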
# ----------------------------------------------------------------------
# 6. Train/Validation Split
# ----------------------------------------------------------------------

x_train, x_val, y_train, y_val = train_test_split(
    x_combined, y_combined, test_size=0.1, random_state=42
)

# Convert to PyTorch format
train_dataset = TensorDataset(
    torch.tensor(x_train.transpose(0, 3, 1, 2), dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.long)
)
val_dataset = TensorDataset(
    torch.tensor(x_val.transpose(0, 3, 1, 2), dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.long)
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

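# Shape sanity check (an added sketch): one training batch should be NCHW float images
# plus a vector of integer class labels.
_xb, _yb = next(iter(train_loader))
print(f"[INFO] Batch images: {tuple(_xb.shape)}, batch labels: {tuple(_yb.shape)}")
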
# ----------------------------------------------------------------------
# 7. CNN Architecture
# ----------------------------------------------------------------------

class CNN(nn.Module):
    """
    This CNN is designed to:
    - Use 3 convolutional layers with increasing depth (32 -> 64 -> 128)
    - Use BatchNorm to stabilize training
    - Use Dropout to prevent overfitting
    - Flatten and use 2 dense layers to classify
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),   # Small receptive field
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),  # Slightly deeper
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1),                  # Helps regularize
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 11)  # 0-9 digits + blank (class 10)
        )

    def forward(self, x):
        return self.classifier(self.features(x))

model = CNN().to(device)

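# Architecture sanity check (an added sketch): two 2x2 max-pools reduce 28 -> 14 -> 7,
# matching the 128 * 7 * 7 flatten size, and the head emits 11 logits. eval() is used so
# the single dummy sample does not trip BatchNorm's batch statistics; the training loop
# below switches back to train() mode at the start of every epoch.
model.eval()
with torch.no_grad():
    _probe = model(torch.zeros(1, 1, 28, 28, device=device))
print(f"[INFO] Probe output shape: {tuple(_probe.shape)}, "
      f"trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
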
# ----------------------------------------------------------------------
# 8. Training Configuration
# ----------------------------------------------------------------------

# CrossEntropyLoss is standard for multi-class classification
criterion = nn.CrossEntropyLoss()

# Adam is used because it's efficient for noisy gradients & fast convergence
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ReduceLROnPlateau reduces the LR when the validation loss plateaus (adaptive control)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=2, min_lr=1e-6)

# Early stopping is used to prevent overfitting and wasted training
patience = 5
patience_counter = 0
best_val_loss = float("inf")
best_model_state = None

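# Small logging helper (an addition, not part of the original pipeline): ReduceLROnPlateau
# adjusts the learning rate silently, so exposing the current value makes the schedule visible.
def current_lr(opt):
    """Return the learning rate of the optimizer's first parameter group."""
    return opt.param_groups[0]["lr"]

print(f"[INFO] Initial learning rate: {current_lr(optimizer)}")
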
# ----------------------------------------------------------------------
# 9. Training Loop
# ----------------------------------------------------------------------

for epoch in range(1, 51):
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Apply data augmentation on CPU (per sample, since the Compose pipeline works on single images)
        for i in range(len(images)):
            images[i] = train_transform(images[i].cpu()).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_acc = 100 * correct / total
    train_loss = running_loss / len(train_loader)

    # ----------------
    # Validation phase
    # ----------------
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_acc = 100 * val_correct / val_total
    val_loss /= len(val_loader)

    print(f"Epoch {epoch:02d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
          f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")

    # Adjust learning rate if the validation loss has plateaued
    scheduler.step(val_loss)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Deep-copy the weights: state_dict() returns references that later updates would overwrite
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("[INFO] Early stopping triggered.")
            break

# Load best model
model.load_state_dict(best_model_state)

# Save PyTorch weights
torch.save(model.state_dict(), "mnist_emnist_blank_cnn_v1.pth")
print("[INFO] Model weights saved as mnist_emnist_blank_cnn_v1.pth")

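# Reload sketch (an illustration, not required by the pipeline): the saved weights can be
# restored into a fresh CNN instance for inference.
_reloaded = CNN().to(device)
_reloaded.load_state_dict(torch.load("mnist_emnist_blank_cnn_v1.pth", map_location=device))
_reloaded.eval()
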
# Convert to TorchScript for deployment (required by Hugging Face Inference API)
model.eval()
example_input = torch.randn(1, 1, 28, 28).to(device)
scripted_model = torch.jit.trace(model, example_input)
scripted_model.save("mnist_emnist_blank_cnn_v1.pt")
print("[INFO] TorchScript model saved as mnist_emnist_blank_cnn_v1.pt")

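# Parity check (an added sketch): the traced module should reproduce the eager model's
# logits on the same example input; a non-negligible difference would indicate a tracing problem.
with torch.no_grad():
    _traced_diff = (model(example_input) - scripted_model(example_input)).abs().max().item()
print(f"[INFO] TorchScript max abs diff vs eager model: {_traced_diff:.2e}")
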
# ONNX export
# We move to CPU just for export (then restore the device).
prev_device = next(model.parameters()).device
try:
    model_cpu = model.to("cpu").eval()
    dummy = torch.randn(1, 1, 28, 28)  # match input shape

    onnx_path = "mnist_emnist_blank_cnn_v1.onnx"
    torch.onnx.export(
        model_cpu,
        dummy,
        onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
    )
    print(f"[INFO] ONNX model saved as {onnx_path}")
finally:
    model.to(prev_device).eval()  # restore original device
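# Optional verification sketch (assumes onnxruntime is installed; skipped otherwise):
# run the exported graph once and confirm it produces an 11-way logits vector.
try:
    import onnxruntime as ort

    _sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    _ort_logits = _sess.run(["logits"], {"input": np.zeros((1, 1, 28, 28), dtype=np.float32)})[0]
    print(f"[INFO] ONNX Runtime output shape: {_ort_logits.shape}")
except ImportError:
    print("[INFO] onnxruntime not installed; skipping ONNX verification.")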