Spaces:
Runtime error
Runtime error
hkanumilli
commited on
Commit
•
dcb34da
1
Parent(s):
f1a2cc8
updating with final model
Browse files- .DS_Store +0 -0
- MNISTModel.pth +3 -0
- app.py +3 -2
- neural_network.py +1 -1
- train_model.py +7 -5
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
MNISTModel.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca929cef293140ec31967298cfae2f56b0b4a319d445a81cc60dd298064df843
|
3 |
+
size 1688783
|
app.py
CHANGED
@@ -9,8 +9,9 @@ transform = transforms.Compose([
|
|
9 |
transforms.Normalize((0.5,), (0.5,)) # Normalize the image
|
10 |
])
|
11 |
|
12 |
-
# Load the trained model
|
13 |
-
net =
|
|
|
14 |
LABELS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
|
15 |
|
16 |
def predict(drawing):
|
|
|
9 |
transforms.Normalize((0.5,), (0.5,)) # Normalize the image
|
10 |
])
|
11 |
|
12 |
+
# Load the trained model
|
13 |
+
net = MNISTNetwork()
|
14 |
+
net.load_state_dict(torch.load('MNISTModel.pth'))
|
15 |
LABELS = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
|
16 |
|
17 |
def predict(drawing):
|
neural_network.py
CHANGED
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
|
22 |
# return F.log_softmax(x, dim=1)
|
23 |
|
24 |
class MNISTNetwork(nn.Module):
|
25 |
-
# achieved 98.
|
26 |
def __init__(self):
|
27 |
super().__init__()
|
28 |
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
|
|
|
22 |
# return F.log_softmax(x, dim=1)
|
23 |
|
24 |
class MNISTNetwork(nn.Module):
|
25 |
+
# achieved 98.783 percent accuracy
|
26 |
def __init__(self):
|
27 |
super().__init__()
|
28 |
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
|
train_model.py
CHANGED
@@ -8,12 +8,11 @@ from neural_network import MNISTNetwork
|
|
8 |
# hyperparameters
|
9 |
BATCH_SIZE = 64
|
10 |
NUM_WORKERS = 2
|
11 |
-
EPOCH =
|
12 |
LEARNING_RATE = 0.01
|
13 |
MOMENTUM = 0.5
|
14 |
LOSS = torch.nn.CrossEntropyLoss()
|
15 |
|
16 |
-
|
17 |
## Step 1: define our transforms
|
18 |
transform = transforms.Compose(
|
19 |
[
|
@@ -24,7 +23,10 @@ transform = transforms.Compose(
|
|
24 |
|
25 |
## Step 2: get our datasets
|
26 |
full_ds = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transform)
|
27 |
-
|
|
|
|
|
|
|
28 |
test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=False, transform=transform)
|
29 |
|
30 |
## Step 3: create our dataloaders
|
@@ -43,7 +45,7 @@ table.field_names = ['Epoch', 'Training Loss', 'Validation Accuracy']
|
|
43 |
|
44 |
if __name__ == "__main__":
|
45 |
multiprocessing.freeze_support()
|
46 |
-
|
47 |
for e in range(EPOCH):
|
48 |
model.train()
|
49 |
running_loss = 0.0
|
@@ -82,5 +84,5 @@ if __name__ == "__main__":
|
|
82 |
test_acc = round((correct/total)*100, 3)
|
83 |
|
84 |
print(f'Test Accuracy: {test_acc}')
|
85 |
-
torch.save(model, 'MNISTModel.pth')
|
86 |
|
|
|
8 |
# hyperparameters
|
9 |
BATCH_SIZE = 64
|
10 |
NUM_WORKERS = 2
|
11 |
+
EPOCH = 15
|
12 |
LEARNING_RATE = 0.01
|
13 |
MOMENTUM = 0.5
|
14 |
LOSS = torch.nn.CrossEntropyLoss()
|
15 |
|
|
|
16 |
## Step 1: define our transforms
|
17 |
transform = transforms.Compose(
|
18 |
[
|
|
|
23 |
|
24 |
## Step 2: get our datasets
|
25 |
full_ds = torchvision.datasets.MNIST(root='./data', train=True, download=False, transform=transform)
|
26 |
+
train_size = int(0.8 * len(full_ds)) # Use 80% of the data for training
|
27 |
+
val_size = len(full_ds) - train_size # Use the remaining 20% for validation
|
28 |
+
|
29 |
+
train_ds, valid_ds = torch.utils.data.random_split(full_ds, [train_size, val_size])
|
30 |
test_ds = torchvision.datasets.MNIST(root='./data', train=False, download=False, transform=transform)
|
31 |
|
32 |
## Step 3: create our dataloaders
|
|
|
45 |
|
46 |
if __name__ == "__main__":
|
47 |
multiprocessing.freeze_support()
|
48 |
+
# begin training process
|
49 |
for e in range(EPOCH):
|
50 |
model.train()
|
51 |
running_loss = 0.0
|
|
|
84 |
test_acc = round((correct/total)*100, 3)
|
85 |
|
86 |
print(f'Test Accuracy: {test_acc}')
|
87 |
+
torch.save(model.state_dict(), 'MNISTModel.pth')
|
88 |
|