karanravindra
commited on
Commit
•
d3c734a
1
Parent(s):
b4964a9
Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,58 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- ylecun/mnist
|
5 |
+
pipeline_tag: image-classification
|
6 |
+
---
|
7 |
+
|
8 |
+
# DigitNet
|
9 |
+
|
10 |
+
[![Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/karanravindra/digitnet)
|
11 |
+
|
12 |
+
**DigitNet** is a simple convolutional neural network designed to classify handwritten digits from the MNIST dataset. It is implemented in Python using the following libraries:
|
13 |
+
|
14 |
+
- **PyTorch** for GPU acceleration (specifically, using the MPS backend).
|
15 |
+
- **PyTorch Lightning** as a high-level wrapper for PyTorch to simplify training.
|
16 |
+
|
17 |
+
## Dataset
|
18 |
+
|
19 |
+
![Example image](./assets/examples.png)
|
20 |
+
|
21 |
+
The model is trained on the **MNIST** dataset, which consists of 60,000 training images and 10,000 test images of handwritten digits. The model achieves:
|
22 |
+
|
23 |
+
- **Top-1 accuracy**: 98.5% on the test set.
|
24 |
+
- **Top-2 accuracy**: 99.8% on a highly augmented test set.
|
25 |
+
|
26 |
+
> **Top-k accuracy** refers to the model's ability to include the correct label in its top `k` predictions. For example, a top-2 accuracy of 99.8% means that the correct label is one of the top two predictions 99.8% of the time.
|
27 |
+
|
28 |
+
## Model Architecture
|
29 |
+
|
30 |
+
The architecture of DigitNet is inspired by:
|
31 |
+
|
32 |
+
- [ConvNeXt](https://arxiv.org/pdf/2201.03545)
|
33 |
+
- [ConvNet-v2](https://arxiv.org/pdf/2301.00808) (specifically, the Global Response Normalization (GRN) block)
|
34 |
+
|
35 |
+
The model, however, is **not** trained with the Fully Convolutional Masked Autoencoder (FCMAE).
|
36 |
+
|
37 |
+
### Key Model Features:
|
38 |
+
|
39 |
+
- **Residual Depthwise Separable Convolutions**
|
40 |
+
- **1-1-3-1 bottleneck structure**
|
41 |
+
- ~500k parameters (intentionally overparameterized)
|
42 |
+
- Trained using **AdamW** optimizer
|
43 |
+
|
44 |
+
The model is trained with a batch size of 128 for 10 epochs.
|
45 |
+
|
46 |
+
## Training
|
47 |
+
|
48 |
+
The model is optimized using a **cross-entropy loss** without label smoothing. This choice was made because [label smoothing can negatively impact teacher distillation](https://arxiv.org/pdf/1906.02629).
|
49 |
+
|
50 |
+
### Results
|
51 |
+
|
52 |
+
![Confusion matrix](./assets/cm.png)
|
53 |
+
|
54 |
+
The confusion matrix above shows the model's performance on the test set. The model performs well on most digits, with the most confusion between 4s, 7s, and 9s. It also seems to over predict 7s.
|
55 |
+
|
56 |
+
## License
|
57 |
+
|
58 |
+
This project is licensed under the **MIT License**. See the [LICENSE](LICENSE) file for more details.
|