Svenni551 commited on
Commit
e91d5c1
1 Parent(s): e36420b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +65 -9
README.md CHANGED
@@ -44,30 +44,86 @@ Below is a simple example of how to load `MyNeuralNet` and use it to predict MNI
44
  ```python
45
  import torch
46
  import torch.nn as nn
47
- from torch import load
 
48
  from huggingface_hub import hf_hub_download
49
 
 
 
 
 
 
 
 
50
  class MyNeuralNet(nn.Module):
51
  def __init__(self):
52
  super(MyNeuralNet, self).__init__()
53
- self.Matrix1 = nn.Linear(28*28, 100)
54
  self.Matrix2 = nn.Linear(100, 50)
55
  self.Matrix3 = nn.Linear(50, 10)
56
  self.R = nn.ReLU()
57
-
58
  def forward(self, x):
59
- x = x.view(-1, 28*28)
60
  x = self.R(self.Matrix1(x))
61
  x = self.R(self.Matrix2(x))
62
  x = self.Matrix3(x)
63
  return x.squeeze()
64
 
65
- model_state_dict = load(hf_hub_download(repo_id="Svenni551/may-mnist-digits", filename="model.pth"), map_location=torch.device('cpu'))
66
- model = MyNeuralNet()
67
- model.load_state_dict(model_state_dict)
68
- model.eval()
69
 
70
- # Use 'model' for predictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ```
72
 
73
  ## Performance
 
44
  ```python
45
  import torch
46
  import torch.nn as nn
47
+ import torch.nn.functional as F
48
+ from torch.utils.data import Dataset, DataLoader
49
  from huggingface_hub import hf_hub_download
50
 
51
+
52
+ # Ensure the device selection logic is centralized
53
+ def get_device():
54
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+
56
+
57
+ # Define the neural network architecture
58
  class MyNeuralNet(nn.Module):
59
  def __init__(self):
60
  super(MyNeuralNet, self).__init__()
61
+ self.Matrix1 = nn.Linear(28 * 28, 100)
62
  self.Matrix2 = nn.Linear(100, 50)
63
  self.Matrix3 = nn.Linear(50, 10)
64
  self.R = nn.ReLU()
65
+
66
  def forward(self, x):
67
+ x = x.view(-1, 28 * 28)
68
  x = self.R(self.Matrix1(x))
69
  x = self.R(self.Matrix2(x))
70
  x = self.Matrix3(x)
71
  return x.squeeze()
72
 
 
 
 
 
73
 
74
+ # Define the custom dataset class
75
+ class CTDataset(Dataset):
76
+ def __init__(self, filepath, device):
77
+ # Add 'device' as a parameter to the class constructor
78
+ x, y = torch.load(filepath)
79
+ self.x = x.float().div(255).to(device) # Use the passed 'device' for tensor operations
80
+ self.y = F.one_hot(y, num_classes=10).float().to(device)
81
+
82
+ def __len__(self):
83
+ return self.x.shape[0]
84
+
85
+ def __getitem__(self, ix):
86
+ return self.x[ix], self.y[ix]
87
+
88
+
89
+ def load_model():
90
+ device = get_device()
91
+ model_state_dict = torch.load(hf_hub_download(repo_id="Svenni551/may-mnist-digits", filename="model.pth"),
92
+ map_location=torch.device(device))
93
+ model = MyNeuralNet().to(device)
94
+ model.load_state_dict(model_state_dict)
95
+ model.eval()
96
+ return model
97
+
98
+
99
+ def predict(input_data):
100
+ device = get_device()
101
+ model = load_model()
102
+ if isinstance(input_data, str): # Assuming filepath to dataset
103
+ dataset = CTDataset(input_data, device) # Pass 'device' as an argument
104
+ loader = DataLoader(dataset, batch_size=32, shuffle=False)
105
+ predictions = []
106
+ with torch.no_grad():
107
+ for batch, _ in loader:
108
+ yhat = model(batch).argmax(axis=1).cpu().numpy()
109
+ predictions.extend(yhat)
110
+ return predictions
111
+ elif isinstance(input_data, torch.Tensor):
112
+ if len(input_data.shape) == 3: # Single image
113
+ input_data = input_data.unsqueeze(0) # Add batch dimension
114
+ input_data = input_data.to(device)
115
+ with torch.no_grad():
116
+ prediction = model(input_data).argmax(axis=1).item()
117
+ return prediction
118
+ else:
119
+ raise ValueError("Unsupported input type. Provide a file path to a dataset or a PyTorch Tensor.")
120
+
121
+ # Example usage:
122
+ # prediction = predict('path/to/your/dataset.pt')
123
+ # or for an image:
124
+ # prediction = predict(your_image_tensor)
125
+
126
+ # print(prediction)
127
  ```
128
 
129
  ## Performance