added "Usage" section to README.md
Browse files
README.md
CHANGED
@@ -19,6 +19,65 @@ Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_8/exe
|
|
19 |
|
20 |
Experiment tracking: https://wandb.ai/sadhaklal/custom-cnn-cifar2
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
## Metric
|
23 |
|
24 |
Accuracy on `cifar2_val`: 0.8995
|
|
|
19 |
|
20 |
Experiment tracking: https://wandb.ai/sadhaklal/custom-cnn-cifar2
|
21 |
|
22 |
+
## Usage
|
23 |
+
|
24 |
+
```
|
25 |
+
!pip install -q datasets
|
26 |
+
|
27 |
+
from datasets import load_dataset
|
28 |
+
|
29 |
+
cifar10 = load_dataset("cifar10")
|
30 |
+
label_map = {0: 0, 2: 1}
|
31 |
+
class_names = ['airplane', 'bird']
|
32 |
+
cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
|
33 |
+
cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]
|
34 |
+
|
35 |
+
example = cifar2_val[0]
|
36 |
+
img, label = example
|
37 |
+
|
38 |
+
import torch
|
39 |
+
from torchvision.transforms import v2
|
40 |
+
|
41 |
+
tfms = v2.Compose([
|
42 |
+
v2.ToImage(),
|
43 |
+
v2.ToDtype(torch.float32, scale=True),
|
44 |
+
v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
|
45 |
+
])
|
46 |
+
img = tfms(img)
|
47 |
+
batch = img.unsqueeze(0)
|
48 |
+
|
49 |
+
import torch.nn as nn
|
50 |
+
import torch.nn.functional as F
|
51 |
+
from huggingface_hub import PyTorchModelHubMixin
|
52 |
+
|
53 |
+
class Net(nn.Module, PyTorchModelHubMixin):
|
54 |
+
def __init__(self):
|
55 |
+
super().__init__()
|
56 |
+
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1)
|
57 |
+
self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1, stride=1)
|
58 |
+
self.fc1 = nn.Linear(8 * 8 * 8, 32)
|
59 |
+
self.fc2 = nn.Linear(32, 2)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
out = F.max_pool2d(torch.tanh(self.conv1(x)), kernel_size=2, stride=2) # Output shape: (batch_size, 16, 16, 16)
|
63 |
+
out = F.max_pool2d(torch.tanh(self.conv2(out)), kernel_size=2, stride=2) # Output shape: (batch_size, 8, 8, 8)
|
64 |
+
out = out.view(-1, 8 * 8 * 8) # Output shape: (batch_size, 512)
|
65 |
+
out = torch.tanh(self.fc1(out)) # Output shape: (batch_size, 32)
|
66 |
+
out = self.fc2(out) # Output shape: (batch_size, 2)
|
67 |
+
return out
|
68 |
+
|
69 |
+
model = Net.from_pretrained("sadhaklal/custom-cnn-cifar2")
|
70 |
+
model.eval()
|
71 |
+
|
72 |
+
with torch.no_grad():
|
73 |
+
logits = model(batch)
|
74 |
+
pred = logits[0].argmax().item()
|
75 |
+
proba = torch.softmax(logits, dim=1)
|
76 |
+
|
77 |
+
print(f"Predicted class: {class_names[pred]}")
|
78 |
+
print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
|
79 |
+
```
|
80 |
+
|
81 |
## Metric
|
82 |
|
83 |
Accuracy on `cifar2_val`: 0.8995
|