sadhaklal commited on
Commit
8b8b88d
1 Parent(s): 9f0be86

added "Usage" section to README.md

Browse files
Files changed (1) hide show
  1. README.md +59 -0
README.md CHANGED
@@ -19,6 +19,65 @@ Code: https://github.com/sambitmukherjee/dlwpt-exercises/blob/main/chapter_8/exe
19
 
20
  Experiment tracking: https://wandb.ai/sadhaklal/custom-cnn-cifar2
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ## Metric
23
 
24
  Accuracy on `cifar2_val`: 0.8995
 
19
 
20
  Experiment tracking: https://wandb.ai/sadhaklal/custom-cnn-cifar2
21
 
22
+ ## Usage
23
+
24
+ ```
25
+ !pip install -q datasets
26
+
27
+ from datasets import load_dataset
28
+
29
+ cifar10 = load_dataset("cifar10")
30
+ label_map = {0: 0, 2: 1}
31
+ class_names = ['airplane', 'bird']
32
+ cifar2_train = [(example['img'], label_map[example['label']]) for example in cifar10['train'] if example['label'] in [0, 2]]
33
+ cifar2_val = [(example['img'], label_map[example['label']]) for example in cifar10['test'] if example['label'] in [0, 2]]
34
+
35
+ example = cifar2_val[0]
36
+ img, label = example
37
+
38
+ import torch
39
+ from torchvision.transforms import v2
40
+
41
+ tfms = v2.Compose([
42
+ v2.ToImage(),
43
+ v2.ToDtype(torch.float32, scale=True),
44
+ v2.Normalize(mean=[0.4915, 0.4823, 0.4468], std=[0.2470, 0.2435, 0.2616])
45
+ ])
46
+ img = tfms(img)
47
+ batch = img.unsqueeze(0)
48
+
49
+ import torch.nn as nn
50
+ import torch.nn.functional as F
51
+ from huggingface_hub import PyTorchModelHubMixin
52
+
53
+ class Net(nn.Module, PyTorchModelHubMixin):
54
+ def __init__(self):
55
+ super().__init__()
56
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1)
57
+ self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1, stride=1)
58
+ self.fc1 = nn.Linear(8 * 8 * 8, 32)
59
+ self.fc2 = nn.Linear(32, 2)
60
+
61
+ def forward(self, x):
62
+ out = F.max_pool2d(torch.tanh(self.conv1(x)), kernel_size=2, stride=2) # Output shape: (batch_size, 16, 16, 16)
63
+ out = F.max_pool2d(torch.tanh(self.conv2(out)), kernel_size=2, stride=2) # Output shape: (batch_size, 8, 8, 8)
64
+ out = out.view(-1, 8 * 8 * 8) # Output shape: (batch_size, 512)
65
+ out = torch.tanh(self.fc1(out)) # Output shape: (batch_size, 32)
66
+ out = self.fc2(out) # Output shape: (batch_size, 2)
67
+ return out
68
+
69
+ model = Net.from_pretrained("sadhaklal/custom-cnn-cifar2")
70
+ model.eval()
71
+
72
+ with torch.no_grad():
73
+ logits = model(batch)
74
+ pred = logits[0].argmax().item()
75
+ proba = torch.softmax(logits, dim=1)
76
+
77
+ print(f"Predicted class: {class_names[pred]}")
78
+ print(f"Predicted class probabilities ('airplane' vs. 'bird'): {proba[0].tolist()}")
79
+ ```
80
+
81
  ## Metric
82
 
83
  Accuracy on `cifar2_val`: 0.8995