pytholic committed
Commit 4084d4a
1 Parent(s): 9e079e1

pushing app

.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
 colorTo: blue
 sdk: gradio
 sdk_version: 3.19.1
-app_file: app.py
+app_file: app/app.py
 pinned: false
 license: mit
 ---
app/app.py ADDED
@@ -0,0 +1,101 @@
+import os
+import sys
+
+current = os.path.dirname(os.path.realpath(__file__))
+parent = os.path.dirname(current)
+sys.path.append(parent)
+
+import albumentations as A
+import gradio as gr
+import numpy as np
+import torch
+from albumentations.pytorch import ToTensorV2
+
+from model import Classifier
+
+# Load the model
+model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt")
+model.eval()
+
+# Define labels
+labels = [
+    "dog",
+    "horse",
+    "elephant",
+    "butterfly",
+    "chicken",
+    "cat",
+    "cow",
+    "sheep",
+    "spider",
+    "squirrel",
+]
+
+# Preprocess function: resize, normalize, and convert to a CHW tensor
+def preprocess(image):
+    image = np.array(image)
+    resize = A.Resize(224, 224)
+    normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+    to_tensor = ToTensorV2()
+    transform = A.Compose([resize, normalize, to_tensor])
+    image = transform(image=image)["image"]
+    return image
+
+
+# Define the sample images
+sample_images = {
+    "dog": "./test_images/dog.jpeg",
+    "cat": "./test_images/cat.jpeg",
+    "butterfly": "./test_images/butterfly.jpeg",
+    "elephant": "./test_images/elephant.jpg",
+    "horse": "./test_images/horse.jpeg",
+}
+
+# Define the function to make predictions on an image
+def predict(image):
+    try:
+        image = preprocess(image).unsqueeze(0)
+
+        # Make a prediction on the image
+        with torch.no_grad():
+            output = model(image)
+            # The model returns log-probabilities; exponentiate
+            # to recover probabilities
+            probabilities = torch.exp(output[0])
+            topk_prob, topk_label = torch.topk(probabilities, 3)
+
+        # Return the top 3 predictions
+        return {labels[topk_label[i]]: float(topk_prob[i]) for i in range(3)}
+    except Exception as e:
+        print(f"Error predicting image: {e}")
+        return {}
+
+
+# Define the interface
+def app():
+    title = "Animal-10 Image Classification"
+    description = "Classify images using a custom CNN model and deploy using Gradio."
+
+    gr.Interface(
+        title=title,
+        description=description,
+        fn=predict,
+        inputs=gr.Image(type="pil"),
+        outputs=gr.Label(num_top_classes=3),
+        examples=list(sample_images.values()),
+    ).launch()
+
+
+# Run the app
+if __name__ == "__main__":
+    app()
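A quick way to sanity-check the new predict() outside Gradio (a sketch, not part of this commit; it assumes the checkpoint and test images are present and the repo root is the working directory):

# Sketch (not part of the commit): smoke-test predict() from a Python shell.
from PIL import Image

from app.app import predict  # import path assumed from the app/ layout above

img = Image.open("./test_images/dog.jpeg").convert("RGB")
print(predict(img))  # a dict mapping the 3 highest-probability labels to scores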
config/args.py ADDED
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Args:
+    """
+    Training arguments.
+    """
+
+    # Learning rate for the optimizer
+    learning_rate: float = 1e-3
+    # Training batch size
+    batch_size: int = 32
+    # Total number of classes
+    num_classes: int = 10
+    # Maximum number of training epochs
+    max_epochs: int = 5
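model.py below consumes these defaults through simple-parsing, which turns each dataclass field into a command-line flag. A minimal round-trip sketch (not part of this commit; flag names assumed to mirror the field names, which is simple-parsing's default):

# Sketch (not part of the commit): Args fields become CLI flags.
from simple_parsing import ArgumentParser

from config.args import Args

parser = ArgumentParser()
parser.add_arguments(Args, dest="options")
# An explicit argv just for illustration; omit it to read sys.argv.
opts = parser.parse_args(["--learning_rate", "3e-4", "--batch_size", "64"]).options
print(opts)  # Args(learning_rate=0.0003, batch_size=64, num_classes=10, max_epochs=5)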
model.py ADDED
@@ -0,0 +1,87 @@
+import pytorch_lightning as pl
+import torch
+import torchmetrics
+from simple_parsing import ArgumentParser
+from torch import nn
+from torch.nn import functional as F
+
+from config.args import Args
+
+parser = ArgumentParser()
+parser.add_arguments(Args, dest="options")
+# parse_known_args so importing this module (e.g. from the Gradio app)
+# does not crash on unrelated command-line arguments
+args_namespace, _ = parser.parse_known_args()
+args = args_namespace.options
+
+# Model class
+class Model(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(3, 32, 5)
+        self.conv2 = nn.Conv2d(32, 64, 5)
+        self.conv3 = nn.Conv2d(64, 128, 3)
+        self.dropout1 = nn.Dropout2d(0.25)
+        self.dropout2 = nn.Dropout2d(0.5)
+
+        # Dummy forward pass to discover the flattened feature size
+        x = torch.randn(3, 224, 224).view(-1, 3, 224, 224)
+        self._to_linear = None
+        self.convs(x)
+
+        self.fc1 = nn.Linear(self._to_linear, 128)
+        self.fc2 = nn.Linear(128, args.num_classes)
+
+    def convs(self, x):
+        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
+        x = self.dropout1(x)
+        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
+        x = self.dropout2(x)
+        x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
+
+        if self._to_linear is None:
+            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
+        return x
+
+    def forward(self, x):
+        x = self.convs(x)
+        x = x.view(-1, self._to_linear)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+class Classifier(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+
+        self.model = Model()
+        self.accuracy = torchmetrics.Accuracy(
+            task="multiclass", num_classes=args.num_classes
+        )
+
+    def forward(self, x):
+        x = self.model(x)
+        return x
+
+    def nll_loss(self, logits, labels):
+        return F.nll_loss(logits, labels)
+
+    def training_step(self, train_batch, batch_idx):
+        x, y = train_batch
+        logits = self.model(x)
+        loss = self.nll_loss(logits, y)
+        acc = self.accuracy(logits, y)
+        self.log("accuracy/train_accuracy", acc)
+        self.log("loss/train_loss", loss)
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        x, y = val_batch
+        logits = self.model(x)
+        loss = self.nll_loss(logits, y)
+        acc = self.accuracy(logits, y)
+        self.log("accuracy/val_accuracy", acc)
+        self.log("loss/val_loss", loss)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=args.learning_rate)
+        return optimizer
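The dummy forward pass in Model.__init__ exists to discover the flattened feature size for fc1. For the fixed 224x224 input the same number can be worked out by hand, since each block is an unpadded convolution followed by a 2x2 max-pool (a sketch, not part of this commit):

# Sketch (not part of the commit): spatial size after each conv + pool block.
size = 224
for k in (5, 5, 3):              # kernel sizes of conv1, conv2, conv3
    size = (size - k + 1) // 2   # valid conv shrinks by k - 1; 2x2 pool halves (floor)
print(size)                      # 25
print(128 * size * size)         # 80000, the value convs() stores in _to_linear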
models/checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85f9ff02a03ded56ff20903f0227f017ec351a9946b7fa0a1b7c33b7107427d6
+size 124442154
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch
+torchvision
+pytorch-lightning
+simple-parsing
+albumentations
+matplotlib
test_images/butterfly.jpeg ADDED

Git LFS Details

  • SHA256: bc9af0c10dd0c0e9cdd878ad5af02795a820a35e4f2e87e6d32e697b87d9fd35
  • Pointer size: 129 Bytes
  • Size of remote file: 9.73 kB
test_images/cat.jpeg ADDED

Git LFS Details

  • SHA256: f7aed3722268778683e9728e3b3170fe43deb08801bf74f8d6e580ba760451c9
  • Pointer size: 130 Bytes
  • Size of remote file: 11.9 kB
test_images/dog.jpeg ADDED

Git LFS Details

  • SHA256: 5f3586190eaafb3574cd767392df808e1bd933ea81ada9b7d36a19b03e67fc17
  • Pointer size: 129 Bytes
  • Size of remote file: 5.91 kB
test_images/elephant.jpg ADDED

Git LFS Details

  • SHA256: dae644dbbfe1ae39176f8afeb2c66ed9ee68717a2f3097f71ee713821c9d18e6
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
test_images/horse.jpeg ADDED

Git LFS Details

  • SHA256: 6e25581a57b2a47ec47cd20bcd355ff3dcd9c02b9449d556d321342e9c94961c
  • Pointer size: 130 Bytes
  • Size of remote file: 15.9 kB