gauthamk commited on
Commit
f3832f8
β€’
1 Parent(s): 97d74ab

Add Gradio and Documentation

Browse files
README.md CHANGED
@@ -10,4 +10,19 @@ pinned: false
10
  license: gpl-3.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  license: gpl-3.0
11
  ---
12
 
13
+ # Pytorch Resnet34 Bird Classification
14
+
15
+ The project is an implementation of the ResNet34 model as per the [Microsoft Research Paper](https://arxiv.org/abs/1512.03385). The model is build using PyTorch and is trained on the [Birds Classification Dataset](https://www.kaggle.com/datasets/gpiosenka/100-bird-species) from Kaggle.
16
+
17
+ ## πŸš€ Getting Started
18
+
19
+ All the code for training the model and exporting to ONNX format is present in the [notebook](notebooks) folder or you can use this [Kaggle Notebook](https://www.kaggle.com/gauthamkrishnan119/pytorch-resnet34-birds-classification) for training the model. It took ~1.5 hours to train the model on the complete dataset using a P100 GPU. The [app.py](app.py) file contains the code for deploying the model using Gradio.
20
+
21
+ ## πŸ€— Demo
22
+
23
+ You can try out the model on [Hugging Face Spaces](https://huggingface.co/spaces/gauthamk/pytorch-resnet34-bird-classification)
24
+
25
+ ## πŸ–₯️ Sample Interface
26
+
27
+ ![Sample Inference](samples/sample1.png)
28
+ ![Sample Inference](samples/sample2.png)
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from functions import *
4
+
5
+ examples_dir = 'examples'
6
+ title = "Birds Classification - ResNet34 PyTorch"
7
+ examples = [os.path.join(examples_dir, i) for i in os.listdir('examples')]
8
+
9
+ interface = gr.Interface(fn=predict, inputs=gr.Image(type= 'numpy', shape=(224, 224)).style(height= 256),
10
+ outputs= gr.Label(num_top_classes= 5),
11
+ examples= examples, title= title, css= '.gr-box {background-color: rgb(230 230 230);}')
12
+
13
+ interface.launch()
examples/AZURE JAY.jpg ADDED
examples/IBERIAN MAGPIE.jpg ADDED
examples/LILAC ROLLER.jpg ADDED
examples/MALAGASY WHITE EYE.jpg ADDED
examples/NORTHERN PARULA.jpg ADDED
functions.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import onnxruntime as rt
4
+
5
+ model_path = 'model/model.onnx'
6
+ class_path = 'model/birds_name_mapping.json'
7
+
8
+ normalise_means = [0.4914, 0.4822, 0.4465]
9
+ normalise_stds = [0.2023, 0.1994, 0.2010]
10
+
11
+ def normalise_image(image):
12
+ image = image.copy()
13
+ for i in range(3):
14
+ image[:, i, :, :] = (image[:, i, :, :] - normalise_means[i]) / normalise_stds[i]
15
+ return image
16
+
17
+ def load_class_names():
18
+ with open(class_path, 'r') as f:
19
+ class_names = json.load(f)
20
+ return class_names
21
+
22
+ def predict(inp_image):
23
+
24
+ class_names = load_class_names()
25
+
26
+ image = inp_image
27
+ image = image.transpose((2, 0, 1))
28
+
29
+ image = image / 255.0
30
+ image = np.expand_dims(image, axis=0)
31
+ image = normalise_image(image)
32
+ image = image.astype(np.float32)
33
+
34
+ sess = rt.InferenceSession(model_path)
35
+
36
+ input_name = sess.get_inputs()[0].name
37
+ output_name = sess.get_outputs()[0].name
38
+
39
+ output = sess.run([output_name], {input_name: image})[0]
40
+ prob = np.exp(output) / np.sum(np.exp(output), axis=1, keepdims=True)
41
+
42
+ top5 = np.argsort(prob[0])[-5:][::-1]
43
+
44
+ class_probs = {class_names[str(i)]: float(prob[0][i]) for i in top5}
45
+ print(class_probs)
46
+
47
+ return class_probs
samples/sample1.png ADDED
samples/sample2.png ADDED