Vedmani commited on
Commit
6b70dc2
1 Parent(s): 5e63017

added app.py

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: AvianVision
3
- emoji: 👁
4
  colorFrom: pink
5
  colorTo: purple
6
  sdk: gradio
 
1
  ---
2
  title: AvianVision
3
+ emoji:
4
  colorFrom: pink
5
  colorTo: purple
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import EfficientNet
2
+ from utils import get_device
3
+ import torch
4
+ import json
5
+ import gradio as gr
6
+ import torch
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import json
10
+ import timm
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ def load_efficientnet_model(model_path: str, device=get_device()):
15
+ """
16
+ Load a PyTorch model checkpoint.
17
+
18
+ Args:
19
+ model_path: The path of the checkpoint file.
20
+ device: The device to load the model onto.
21
+
22
+ Returns:
23
+ The model loaded onto the specified device.
24
+ """
25
+ # Initialize model
26
+ model = EfficientNet()
27
+
28
+ # Load model weights onto the specified device
29
+ model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
30
+
31
+ # Set model to evaluation mode
32
+ model.eval()
33
+
34
+ return model
35
+
36
+ with open('idx_to_class.json', 'r') as f:
37
+ idx_to_class = json.load(f)
38
+
39
+
40
+ def predict_image(array):
41
+ """
42
+ Predict the class of an image.
43
+
44
+ Args:
45
+ array: The image data as an array.
46
+
47
+ Returns:
48
+ The predicted class.
49
+ """
50
+ # Convert the image to a PIL Image object
51
+ input_image = Image.fromarray(array)
52
+
53
+ # Load the model
54
+ model = load_efficientnet_model('/home/vedmani/Downloads/efficientnet_epoch=18_loss=0.0020_val_f1score=0.8993.pth')
55
+
56
+ # Transform the image
57
+ transform = transforms.Compose([
58
+ transforms.Resize(size=(150, 150)),
59
+ transforms.ToTensor(),
60
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
61
+ ])
62
+ image = transform(input_image).unsqueeze(0)
63
+ image.to(get_device())
64
+
65
+ # Predict the class
66
+ with torch.no_grad():
67
+ output = model(image)
68
+ # Apply softmax to the outputs to convert them into probabilities
69
+ probabilities = F.softmax(output, dim=1)
70
+ predicted = probabilities.argmax().item()
71
+ predicted_class = idx_to_class[str(predicted)] # Make sure your keys in json are string type
72
+
73
+ return predicted_class
74
+
75
+
76
+ # Create the image classifier
77
+ image_classifier = gr.Interface(fn=predict_image, inputs="image", outputs="text", allow_flagging='Never')
78
+
79
+ # Launch the image classifier
80
+ image_classifier.launch(share=True)
efficientnet_epoch=18_loss=0.0020_val_f1score=0.8993.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86679a9101ac637ac321200f69d807b92d8c7419879111a8f598cd24d7987445
3
+ size 49005075
idx_to_class.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0": "Asian-Green-Bee-Eater", "1": "Brown-Headed-Barbet", "2": "Cattle-Egret", "3": "Common-Kingfisher", "4": "Common-Myna", "5": "Common-Rosefinch", "6": "Common-Tailorbird", "7": "Coppersmith-Barbet", "8": "Forest-Wagtail", "9": "Gray-Wagtail", "10": "Hoopoe", "11": "House-Crow", "12": "Indian-Grey-Hornbill", "13": "Indian-Peacock", "14": "Indian-Pitta", "15": "Indian-Roller", "16": "Jungle-Babbler", "17": "Northern-Lapwing", "18": "Red-Wattled-Lapwing", "19": "Ruddy-Shelduck", "20": "Rufous-Treepie", "21": "Sarus-Crane", "22": "White-Breasted-Kingfisher", "23": "White-Breasted-Waterhen", "24": "White-Wagtail"}