Spaces:
Runtime error
Runtime error
pushing app
Browse files- .gitattributes +2 -0
- README.md +1 -1
- app/app.py +101 -0
- config/args.py +17 -0
- model.py +87 -0
- models/checkpoint.ckpt +3 -0
- requirements.txt +6 -0
- test_images/butterfly.jpeg +3 -0
- test_images/cat.jpeg +3 -0
- test_images/dog.jpeg +3 -0
- test_images/elephant.jpg +3 -0
- test_images/horse.jpeg +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
|
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.19.1
|
8 |
-
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
|
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.19.1
|
8 |
+
app_file: app/app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
---
|
app/app.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
current = os.path.dirname(os.path.realpath(__file__))
|
5 |
+
parent = os.path.dirname(current)
|
6 |
+
sys.path.append(parent)
|
7 |
+
|
8 |
+
import albumentations as A
|
9 |
+
import gradio as gr
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
from albumentations.pytorch import ToTensorV2
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from model import Classifier
|
17 |
+
|
18 |
+
# Load the model
|
19 |
+
model = Classifier.load_from_checkpoint("./models/checkpoint.ckpt")
|
20 |
+
model.eval()
|
21 |
+
|
22 |
+
# Define labels
|
23 |
+
labels = [
|
24 |
+
"dog",
|
25 |
+
"horse",
|
26 |
+
"elephant",
|
27 |
+
"butterfly",
|
28 |
+
"chicken",
|
29 |
+
"cat",
|
30 |
+
"cow",
|
31 |
+
"sheep",
|
32 |
+
"spider",
|
33 |
+
"squirrel",
|
34 |
+
]
|
35 |
+
|
36 |
+
# Preprocess function
|
37 |
+
def preprocess(image):
|
38 |
+
image = np.array(image)
|
39 |
+
resize = A.Resize(224, 224)
|
40 |
+
normalize = A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
|
41 |
+
to_tensor = ToTensorV2()
|
42 |
+
transform = A.Compose([resize, normalize, to_tensor])
|
43 |
+
image = transform(image=image)["image"]
|
44 |
+
return image
|
45 |
+
|
46 |
+
|
47 |
+
# Define the sample images
|
48 |
+
sample_images = {
|
49 |
+
"dog": "./test_images/dog.jpeg",
|
50 |
+
"cat": "./test_images/cat.jpeg",
|
51 |
+
"butterfly": "./test_images/butterfly.jpeg",
|
52 |
+
"elephant": "./test_images/elephant.jpg",
|
53 |
+
"horse": "./test_images/horse.jpeg",
|
54 |
+
}
|
55 |
+
|
56 |
+
# Define the function to make predictions on an image
|
57 |
+
def predict(image):
|
58 |
+
try:
|
59 |
+
image = preprocess(image).unsqueeze(0)
|
60 |
+
|
61 |
+
# Prediction
|
62 |
+
# Make a prediction on the image
|
63 |
+
with torch.no_grad():
|
64 |
+
output = model(image)
|
65 |
+
# convert to probabilities
|
66 |
+
probabilities = torch.nn.functional.softmax(torch.exp(output[0]), dim=0)
|
67 |
+
topk_prob, topk_label = torch.topk(probabilities, 3)
|
68 |
+
|
69 |
+
# Return the top 3 predictions
|
70 |
+
return {labels[i]: float(probabilities[i]) for i in range(3)}
|
71 |
+
except Exception as e:
|
72 |
+
print(f"Error predicting image: {e}")
|
73 |
+
return []
|
74 |
+
|
75 |
+
|
76 |
+
# Define the interface
|
77 |
+
def app():
|
78 |
+
title = "Animal-10 Image Classification"
|
79 |
+
description = "Classify images using a custom CNN model and deploy using Gradio."
|
80 |
+
|
81 |
+
gr.Interface(
|
82 |
+
title=title,
|
83 |
+
description=description,
|
84 |
+
fn=predict,
|
85 |
+
inputs=gr.Image(type="pil"),
|
86 |
+
outputs=gr.Label(
|
87 |
+
num_top_classes=3,
|
88 |
+
),
|
89 |
+
examples=[
|
90 |
+
"./test_images/dog.jpeg",
|
91 |
+
"./test_images/cat.jpeg",
|
92 |
+
"./test_images/butterfly.jpeg",
|
93 |
+
"./test_images/elephant.jpg",
|
94 |
+
"./test_images/horse.jpeg",
|
95 |
+
],
|
96 |
+
).launch()
|
97 |
+
|
98 |
+
|
99 |
+
# Run the app
|
100 |
+
if __name__ == "__main__":
|
101 |
+
app()
|
config/args.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class Args:
|
6 |
+
"""
|
7 |
+
Training arguments.
|
8 |
+
"""
|
9 |
+
|
10 |
+
# Learning rate for the optimizer
|
11 |
+
learning_rate: float = 1e-3
|
12 |
+
# Training batch size
|
13 |
+
batch_size: int = 32
|
14 |
+
# Total numebr of classes
|
15 |
+
num_classes: int = 10
|
16 |
+
# Maximum number of training epochs
|
17 |
+
max_epochs: int = 5
|
model.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
import torchmetrics
|
4 |
+
from simple_parsing import ArgumentParser
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from config.args import Args
|
9 |
+
|
10 |
+
parser = ArgumentParser()
|
11 |
+
parser.add_arguments(Args, dest="options")
|
12 |
+
args_namespace = parser.parse_args()
|
13 |
+
args = args_namespace.options
|
14 |
+
|
15 |
+
# Model class
|
16 |
+
class Model(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
self.conv1 = nn.Conv2d(3, 32, 5)
|
21 |
+
self.conv2 = nn.Conv2d(32, 64, 5)
|
22 |
+
self.conv3 = nn.Conv2d(64, 128, 3)
|
23 |
+
self.dropout1 = nn.Dropout2d(0.25)
|
24 |
+
self.dropout2 = nn.Dropout2d(0.5)
|
25 |
+
|
26 |
+
x = torch.randn(3, 224, 224).view(-1, 3, 224, 224)
|
27 |
+
self._to_linear = None
|
28 |
+
self.convs(x)
|
29 |
+
|
30 |
+
self.fc1 = nn.Linear(self._to_linear, 128)
|
31 |
+
self.fc2 = nn.Linear(128, args.num_classes)
|
32 |
+
|
33 |
+
def convs(self, x):
|
34 |
+
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
|
35 |
+
x = self.dropout1(x)
|
36 |
+
x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
|
37 |
+
x = self.dropout2(x)
|
38 |
+
x = F.max_pool2d(F.relu(self.conv3(x)), (2, 2))
|
39 |
+
|
40 |
+
if self._to_linear is None:
|
41 |
+
self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
|
42 |
+
return x
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.convs(x)
|
46 |
+
x = x.view(-1, self._to_linear)
|
47 |
+
x = F.relu(self.fc1(x))
|
48 |
+
x = self.fc2(x)
|
49 |
+
return F.log_softmax(x, dim=1)
|
50 |
+
|
51 |
+
|
52 |
+
class Classifier(pl.LightningModule):
|
53 |
+
def __init__(self):
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
self.model = Model()
|
57 |
+
self.accuracy = torchmetrics.Accuracy(
|
58 |
+
task="multiclass", num_classes=args.num_classes
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
x = self.model(x)
|
63 |
+
return x
|
64 |
+
|
65 |
+
def nll_loss(self, logits, labels):
|
66 |
+
return F.nll_loss(logits, labels)
|
67 |
+
|
68 |
+
def training_step(self, train_batch, batch_idx):
|
69 |
+
x, y = train_batch
|
70 |
+
logits = self.model(x)
|
71 |
+
loss = self.nll_loss(logits, y)
|
72 |
+
acc = self.accuracy(logits, y)
|
73 |
+
self.log("accuracy/train_accuracy", acc)
|
74 |
+
self.log("loss/train_loss", loss)
|
75 |
+
return loss
|
76 |
+
|
77 |
+
def validation_step(self, val_batch, batch_idx):
|
78 |
+
x, y = val_batch
|
79 |
+
logits = self.model(x)
|
80 |
+
loss = self.nll_loss(logits, y)
|
81 |
+
acc = self.accuracy(logits, y)
|
82 |
+
self.log("accuracy/val_accuracy", acc)
|
83 |
+
self.log("loss/val_loss", loss)
|
84 |
+
|
85 |
+
def configure_optimizers(self):
|
86 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=args.learning_rate)
|
87 |
+
return optimizer
|
models/checkpoint.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85f9ff02a03ded56ff20903f0227f017ec351a9946b7fa0a1b7c33b7107427d6
|
3 |
+
size 124442154
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
pytorch-lightning
|
4 |
+
simple-parsing
|
5 |
+
albumentations
|
6 |
+
matplotlib
|
test_images/butterfly.jpeg
ADDED
Git LFS Details
|
test_images/cat.jpeg
ADDED
Git LFS Details
|
test_images/dog.jpeg
ADDED
Git LFS Details
|
test_images/elephant.jpg
ADDED
Git LFS Details
|
test_images/horse.jpeg
ADDED
Git LFS Details
|