Image Classification
karanravindra committed on
Commit
d3c734a
1 Parent(s): b4964a9

Update README.md

Files changed (1)
  1. README.md +58 -3
README.md CHANGED
@@ -1,3 +1,58 @@
- ---
- license: mit
- ---
---
license: mit
datasets:
- ylecun/mnist
pipeline_tag: image-classification
---

# DigitNet

[![Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/karanravindra/digitnet)

**DigitNet** is a simple convolutional neural network that classifies handwritten digits from the MNIST dataset. It is implemented in Python using the following libraries:

- **PyTorch** for GPU acceleration (specifically, via the MPS backend).
- **PyTorch Lightning**, a high-level wrapper around PyTorch that simplifies training.

## Dataset

![Example images](./assets/examples.png)

The model is trained on the **MNIST** dataset, which consists of 60,000 training images and 10,000 test images of handwritten digits. The model achieves:

- **Top-1 accuracy**: 98.5% on the test set.
- **Top-2 accuracy**: 99.8% on a heavily augmented test set.

> **Top-k accuracy** measures whether the correct label appears among the model's top `k` predictions. For example, a top-2 accuracy of 99.8% means that the correct label is one of the two highest-scoring predictions 99.8% of the time.
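As an illustration, top-k accuracy can be computed in a few lines of plain Python. The function and toy scores below are illustrative only, not taken from the DigitNet code:

```python
def topk_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores, highest first.
        topk = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        hits += label in topk
    return hits / len(labels)

# Toy per-class scores for three samples over three classes.
logits = [
    [0.1, 0.7, 0.2],  # top prediction: class 1
    [0.5, 0.3, 0.2],  # top prediction: class 0, runner-up: class 1
    [0.3, 0.4, 0.3],  # top prediction: class 1
]
labels = [1, 1, 2]

print(topk_accuracy(logits, labels, k=1))  # only the first sample's top guess is right
print(topk_accuracy(logits, labels, k=2))  # the second sample's label is in its top 2
```

Widening `k` can only raise the score, which is why the top-2 figure above exceeds the top-1 figure.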

## Model Architecture

The architecture of DigitNet is inspired by:

- [ConvNeXt](https://arxiv.org/pdf/2201.03545)
- [ConvNeXt V2](https://arxiv.org/pdf/2301.00808) (specifically, its Global Response Normalization (GRN) block)

The model is, however, **not** trained with the Fully Convolutional Masked Autoencoder (FCMAE) objective.

### Key Model Features

- **Residual depthwise separable convolutions**
- **1-1-3-1 bottleneck structure**
- ~500k parameters (intentionally overparameterized)
- Trained with the **AdamW** optimizer

The model is trained with a batch size of 128 for 10 epochs.
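To see why depthwise separable convolutions keep the parameter count low, compare their cost with a standard convolution. The counting below is generic arithmetic, not taken from the DigitNet source:

```python
def conv_params(c_in, c_out, k, bias=True):
    # Standard k x k convolution: one (k * k * c_in) kernel per output channel.
    return c_out * (k * k * c_in + (1 if bias else 0))

def depthwise_separable_params(c_in, c_out, k, bias=True):
    # Depthwise: one k x k kernel per input channel.
    depthwise = c_in * (k * k + (1 if bias else 0))
    # Pointwise: a 1x1 convolution that mixes channels.
    pointwise = c_out * (c_in + (1 if bias else 0))
    return depthwise + pointwise

# Example: a 64 -> 64 channel layer with a 3x3 kernel, no bias.
standard = conv_params(64, 64, 3, bias=False)                   # 36,864 weights
separable = depthwise_separable_params(64, 64, 3, bias=False)   # 576 + 4,096 = 4,672 weights
```

The separable factorization uses roughly an eighth of the weights here, which is how a multi-block network can stay near ~500k parameters.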

## Training

The model is optimized using a **cross-entropy loss** without label smoothing. This choice was made because [label smoothing can negatively impact teacher distillation](https://arxiv.org/pdf/1906.02629).
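For reference, the difference between a hard cross-entropy target and a label-smoothed one can be sketched in plain Python. This is a hypothetical illustration of the loss, not code from the training run:

```python
import math

def cross_entropy(probs, target, smoothing=0.0):
    """Cross-entropy of predicted class probabilities against a target class,
    optionally softened with label smoothing."""
    n = len(probs)
    loss = 0.0
    for i, p in enumerate(probs):
        # With smoothing s, the target puts 1 - s on the true class
        # and spreads s / (n - 1) over the other classes.
        q = (1 - smoothing) if i == target else smoothing / (n - 1)
        loss -= q * math.log(p)
    return loss

probs = [0.05, 0.9, 0.05]
hard = cross_entropy(probs, target=1)                  # -log(0.9), the plain loss used here
soft = cross_entropy(probs, target=1, smoothing=0.1)   # larger: confident predictions are penalized
```

With smoothing, a confidently correct prediction still incurs loss, which softens the teacher's output distribution and is the mechanism the cited paper links to weaker distillation.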

### Results

![Confusion matrix](./assets/cm.png)

The confusion matrix above shows the model's performance on the test set. The model performs well on most digits, with the most confusion among 4s, 7s, and 9s. It also tends to overpredict 7s.
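For readers reproducing the plot, a confusion matrix can be tallied directly from predictions. The helper below is an illustrative sketch, not part of the repository:

```python
def confusion_matrix(y_true, y_pred, num_classes=10):
    """Count predictions into a grid: rows are true digits, columns are predicted digits."""
    m = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        m[t][p] += 1
    return m

# Toy example: two 4s (one mistaken for a 9) and one correctly classified 7.
m = confusion_matrix([4, 4, 7], [4, 9, 7])
```

Off-diagonal mass, such as the 4-predicted-as-9 cell here, is exactly the 4/9 confusion visible in the plot.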

## License

This project is licensed under the **MIT License**. See the [LICENSE](LICENSE) file for more details.