sadhaklal commited on
Commit
d6b7c50
1 Parent(s): 91e97d6

updated the "Usage" section of README.md

Browse files
Files changed (1) hide show
  1. README.md +9 -3
README.md CHANGED
@@ -40,11 +40,11 @@ tfms = v2.Compose([
40
  v2.ToDtype(torch.float32, scale=True)
41
  ])
42
 
 
 
43
  import torch.nn as nn
44
  from huggingface_hub import PyTorchModelHubMixin
45
 
46
- device = torch.device("cpu")
47
-
48
  class MLP(nn.Module, PyTorchModelHubMixin):
49
  def __init__(self):
50
  super().__init__()
@@ -62,6 +62,12 @@ model = MLP.from_pretrained("sadhaklal/mlp-fashion-mnist")
62
  model.to(device)
63
 
64
  example = fashion_mnist['test'][0]
 
 
 
 
 
 
65
  img = tfms(example['image'])
66
  x_batch = img.unsqueeze(0)
67
 
@@ -72,7 +78,7 @@ with torch.no_grad():
72
  proba = torch.softmax(logits, dim=-1)
73
 
74
  confidence, pred = proba.max(dim=-1)
75
- print(f"Predicted class: {pred[0].item()}")
76
  print(f"Predicted confidence: {round(confidence[0].item(), 4)}")
77
  ```
78
 
 
40
  v2.ToDtype(torch.float32, scale=True)
41
  ])
42
 
43
+ device = torch.device("cpu")
44
+
45
  import torch.nn as nn
46
  from huggingface_hub import PyTorchModelHubMixin
47
 
 
 
48
  class MLP(nn.Module, PyTorchModelHubMixin):
49
  def __init__(self):
50
  super().__init__()
 
62
  model.to(device)
63
 
64
  example = fashion_mnist['test'][0]
65
+
66
+ import matplotlib.pyplot as plt
67
+
68
+ plt.imshow(example['image'], cmap='gray')
69
+ print(f"Ground truth: {id2label[example['label']]}")
70
+
71
  img = tfms(example['image'])
72
  x_batch = img.unsqueeze(0)
73
 
 
78
  proba = torch.softmax(logits, dim=-1)
79
 
80
  confidence, pred = proba.max(dim=-1)
81
+ print(f"Predicted class: {id2label[pred[0].item()]}")
82
  print(f"Predicted confidence: {round(confidence[0].item(), 4)}")
83
  ```
84