Spaces:
Sleeping
Sleeping
DataRaptor
commited on
Commit
·
1556762
1
Parent(s):
0e7de7e
Upload 28 files
Browse files- ModelClass.py +119 -0
- app.py +101 -48
- classifier_weights.pth +3 -0
- inputs/Image_102.jpg +0 -0
- inputs/Image_129.jpg +0 -0
- inputs/Image_131.jpg +0 -0
- inputs/Image_18.jpg +0 -0
- inputs/Image_180.jpg +0 -0
- inputs/Image_183.jpg +0 -0
- inputs/Image_202.jpg +0 -0
- inputs/Image_207.jpg +0 -0
- inputs/Image_208.jpg +0 -0
- inputs/Image_209.jpg +0 -0
- inputs/Image_213.jpg +0 -0
- inputs/Image_222.jpg +0 -0
- inputs/Image_267.jpg +0 -0
- inputs/Image_28.jpg +0 -0
- inputs/Image_33.jpg +0 -0
- inputs/Image_35.jpg +0 -0
- inputs/Image_409.jpg +0 -0
- inputs/Image_41.jpg +0 -0
- inputs/Image_416.jpg +0 -0
- inputs/Image_417.jpg +0 -0
- inputs/Image_42.jpg +0 -0
- inputs/Image_516.jpg +0 -0
- inputs/Image_527.jpg +0 -0
- inputs/Image_58.jpg +0 -0
- inputs/Image_7.jpg +0 -0
ModelClass.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from torch import nn, optim
|
4 |
+
from torchvision import transforms, models
|
5 |
+
#from torch_snippets import *
|
6 |
+
#from torch.utils.data import DataLoader, Dataset
|
7 |
+
#from torchsummary import summary
|
8 |
+
|
9 |
+
#import seaborn as sns
|
10 |
+
#import matplotlib.pyplot as plt
|
11 |
+
#from sklearn.model_selection import train_test_split
|
12 |
+
from PIL import Image
|
13 |
+
#import numpy as np
|
14 |
+
#import cv2
|
15 |
+
#from glob import glob
|
16 |
+
#import pandas as pd
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
class ActionClassifier(nn.Module):
|
24 |
+
def __init__(self, ntargets):
|
25 |
+
super().__init__()
|
26 |
+
resnet = models.resnet50(pretrained=True, progress=True)
|
27 |
+
modules = list(resnet.children())[:-1] # delete last layer
|
28 |
+
self.resnet = nn.Sequential(*modules)
|
29 |
+
for param in self.resnet.parameters():
|
30 |
+
param.requires_grad = False
|
31 |
+
self.fc = nn.Sequential(
|
32 |
+
nn.Flatten(),
|
33 |
+
nn.BatchNorm1d(resnet.fc.in_features),
|
34 |
+
nn.Dropout(0.2),
|
35 |
+
nn.Linear(resnet.fc.in_features, 256),
|
36 |
+
nn.ReLU(),
|
37 |
+
nn.BatchNorm1d(256),
|
38 |
+
nn.Dropout(0.2),
|
39 |
+
nn.Linear(256, ntargets)
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x = self.resnet(x)
|
44 |
+
x = self.fc(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
def get_transform():
|
50 |
+
transform = transforms.Compose([
|
51 |
+
transforms.Resize([224, 244]),
|
52 |
+
transforms.ToTensor(),
|
53 |
+
# std multiply by 255 to convert img of [0, 255]
|
54 |
+
# to img of [0, 1]
|
55 |
+
transforms.Normalize((0.485, 0.456, 0.406),
|
56 |
+
(0.229*255, 0.224*255, 0.225*255))]
|
57 |
+
)
|
58 |
+
return transform
|
59 |
+
|
60 |
+
|
61 |
+
def get_model():
|
62 |
+
model = ActionClassifier(15)
|
63 |
+
model.load_state_dict(torch.load('./classifier_weights.pth', map_location=torch.device('cpu')))
|
64 |
+
return model
|
65 |
+
|
66 |
+
|
67 |
+
def get_class(index):
|
68 |
+
ind2cat = [
|
69 |
+
'calling',
|
70 |
+
'clapping',
|
71 |
+
'cycling',
|
72 |
+
'dancing',
|
73 |
+
'drinking',
|
74 |
+
'eating',
|
75 |
+
'fighting',
|
76 |
+
'hugging',
|
77 |
+
'laughing',
|
78 |
+
'listening_to_music',
|
79 |
+
'running',
|
80 |
+
'sitting',
|
81 |
+
'sleeping',
|
82 |
+
'texting',
|
83 |
+
'using_laptop'
|
84 |
+
]
|
85 |
+
return ind2cat[index]
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
# img = Image.open('./inputs/Image_102.jpg').convert('RGB')
|
95 |
+
|
96 |
+
# #print(transform(img))
|
97 |
+
|
98 |
+
# img = transform(img)
|
99 |
+
|
100 |
+
# img = img.unsqueeze(dim=0)
|
101 |
+
# print(img.shape)
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
# model.eval()
|
109 |
+
# with torch.no_grad():
|
110 |
+
# out = model(img)
|
111 |
+
# out = nn.Softmax()(out).squeeze()
|
112 |
+
# print(out.shape)
|
113 |
+
# res = torch.argmax(out)
|
114 |
+
|
115 |
+
# print(ind2cat[res])
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
app.py
CHANGED
@@ -2,6 +2,58 @@ import streamlit as st
|
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
4 |
import requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
# fix sidebar
|
@@ -19,6 +71,10 @@ st.markdown("""
|
|
19 |
margin: 0.125rem 2rem;
|
20 |
background-color: rgb(181 197 227 / 18%) !important;
|
21 |
}
|
|
|
|
|
|
|
|
|
22 |
</style>
|
23 |
""", unsafe_allow_html=True
|
24 |
)
|
@@ -29,12 +85,10 @@ hide_st_style = """
|
|
29 |
header {visibility: hidden;}
|
30 |
</style>
|
31 |
"""
|
32 |
-
st.markdown(hide_st_style, unsafe_allow_html=True)
|
33 |
|
34 |
|
35 |
|
36 |
-
|
37 |
-
# Function to load and predict image
|
38 |
def predict(image):
|
39 |
# Dummy prediction
|
40 |
classes = ['cat', 'dog']
|
@@ -42,53 +96,52 @@ def predict(image):
|
|
42 |
prediction /= np.sum(prediction)
|
43 |
return dict(zip(classes, prediction))
|
44 |
|
45 |
-
|
46 |
-
#st.set_page_config(page_title='Image Classification App', page_icon=':camera:', layout='wide')
|
47 |
-
st.title('HappyWhale')
|
48 |
-
st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/<username>/<project_name>?workspace=user-<username>)")
|
49 |
-
|
50 |
-
st.markdown('This project aims to identify whales and dolphins by their unique characteristics. It can help researchers understand their behavior, population dynamics, and migration patterns. This project can aid researchers in identifying these marine mammals, providing valuable data for conservation efforts. [[Source Code]](https://kaggle.com/)')
|
51 |
-
|
52 |
-
|
53 |
-
# Add file uploader
|
54 |
-
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
|
55 |
-
|
56 |
-
# Add test image selector
|
57 |
-
test_images = {
|
58 |
-
'Cat': 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg',
|
59 |
-
'Dog': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/6e/Golde33443.jpg/1200px-Golde33443.jpg',
|
60 |
-
'Bird': 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0a/Scarlet_Tanager_-_male_%28cropped%29.jpg/1200px-Scarlet_Tanager_-_male_%28cropped%29.jpg'
|
61 |
-
}
|
62 |
-
test_image = st.selectbox('Or choose a test image', list(test_images.keys()))
|
63 |
-
|
64 |
-
st.subheader('Selected Image')
|
65 |
-
# Define layout of app
|
66 |
-
left_column, right_column = st.columns([1, 2.5], gap="medium")
|
67 |
-
with left_column:
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
77 |
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
prediction =
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
-
|
|
|
|
2 |
import numpy as np
|
3 |
from PIL import Image
|
4 |
import requests
|
5 |
+
import ModelClass
|
6 |
+
from glob import glob
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
@st.cache_resource
|
11 |
+
def load_model():
|
12 |
+
return ModelClass.get_model()
|
13 |
+
|
14 |
+
@st.cache_data
|
15 |
+
def get_images():
|
16 |
+
l = glob('./inputs/*')
|
17 |
+
l = {i.split('/')[-1]: i for i in l}
|
18 |
+
return l
|
19 |
+
|
20 |
+
|
21 |
+
def infer(img):
|
22 |
+
image = img.convert('RGB')
|
23 |
+
image = ModelClass.get_transform()(image)
|
24 |
+
image = image.unsqueeze(dim=0)
|
25 |
+
|
26 |
+
model = load_model()
|
27 |
+
model.eval()
|
28 |
+
with torch.no_grad():
|
29 |
+
out = model(image)
|
30 |
+
out = nn.Softmax()(out).squeeze()
|
31 |
+
return out
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
st.set_page_config(
|
37 |
+
page_title="Whale Identification",
|
38 |
+
page_icon="🧊",
|
39 |
+
layout="centered",
|
40 |
+
initial_sidebar_state="expanded",
|
41 |
+
menu_items={
|
42 |
+
'Get Help': 'https://www.extremelycoolapp.com/help',
|
43 |
+
'Report a bug': "https://www.extremelycoolapp.com/bug",
|
44 |
+
'About': """
|
45 |
+
# This is a header. This is an *extremely* cool app!
|
46 |
+
How how are you doin.
|
47 |
+
|
48 |
+
---
|
49 |
+
I am fine
|
50 |
+
|
51 |
+
|
52 |
+
<style>
|
53 |
+
</style>
|
54 |
+
"""
|
55 |
+
}
|
56 |
+
)
|
57 |
|
58 |
|
59 |
# fix sidebar
|
|
|
71 |
margin: 0.125rem 2rem;
|
72 |
background-color: rgb(181 197 227 / 18%) !important;
|
73 |
}
|
74 |
+
.css-1y4p8pa {
|
75 |
+
padding: 3rem 1rem 10rem;
|
76 |
+
max-width: 58rem;
|
77 |
+
}
|
78 |
</style>
|
79 |
""", unsafe_allow_html=True
|
80 |
)
|
|
|
85 |
header {visibility: hidden;}
|
86 |
</style>
|
87 |
"""
|
88 |
+
#st.markdown(hide_st_style, unsafe_allow_html=True)
|
89 |
|
90 |
|
91 |
|
|
|
|
|
92 |
def predict(image):
|
93 |
# Dummy prediction
|
94 |
classes = ['cat', 'dog']
|
|
|
96 |
prediction /= np.sum(prediction)
|
97 |
return dict(zip(classes, prediction))
|
98 |
|
99 |
+
def app():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
+
st.title('ActionNet')
|
102 |
+
st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/<username>/<project_name>?workspace=user-<username>)")
|
103 |
+
st.markdown('This project aims to identify whales and dolphins by their unique characteristics. It can help researchers understand their behavior, population dynamics, and migration patterns. This project can aid researchers in identifying these marine mammals, providing valuable data for conservation efforts. [[Source Code]](https://kaggle.com/)')
|
104 |
+
|
105 |
+
|
106 |
+
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
|
107 |
+
|
108 |
+
test_images = get_images()
|
109 |
+
test_image = st.selectbox('Or choose a test image', list(test_images.keys()))
|
110 |
|
111 |
+
|
112 |
+
st.subheader('Selected Image')
|
113 |
+
|
114 |
+
left_column, right_column = st.columns([1.5, 2.5], gap="medium")
|
115 |
+
with left_column:
|
116 |
|
117 |
+
if uploaded_file is not None:
|
118 |
+
image = Image.open(uploaded_file)
|
119 |
+
st.image(image, use_column_width=True)
|
120 |
+
else:
|
121 |
+
image_url = test_images[test_image]
|
122 |
+
image = Image.open(image_url)
|
123 |
+
st.image(image, use_column_width=True)
|
124 |
+
|
125 |
|
126 |
+
if st.button('✨ Get prediction from AI', type='primary'):
|
127 |
+
spacer = st.empty()
|
128 |
+
|
129 |
+
|
130 |
+
res = infer(image)
|
131 |
+
res = torch.argmax(res)
|
132 |
+
cname = ModelClass.get_class(res)
|
133 |
+
st.write(f'{cname}')
|
134 |
+
|
135 |
+
|
136 |
+
prediction = predict(image)
|
137 |
+
right_column.subheader('Results')
|
138 |
+
for class_name, class_probability in prediction.items():
|
139 |
+
right_column.write(f'{class_name}: {class_probability:.2%}')
|
140 |
+
right_column.progress(class_probability)
|
141 |
+
|
142 |
+
|
143 |
+
st.markdown("---")
|
144 |
+
st.markdown("Built by [Shamim Ahamed](https://your-portfolio-website.com/). Data provided by [Kaggle](https://www.kaggle.com/c/)")
|
145 |
|
146 |
+
|
147 |
+
app()
|
classifier_weights.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac67f64705d4109573e06c65925367ee1697f4359ebf1a227a20ba69d3e7a1ba
|
3 |
+
size 96504105
|
inputs/Image_102.jpg
ADDED
inputs/Image_129.jpg
ADDED
inputs/Image_131.jpg
ADDED
inputs/Image_18.jpg
ADDED
inputs/Image_180.jpg
ADDED
inputs/Image_183.jpg
ADDED
inputs/Image_202.jpg
ADDED
inputs/Image_207.jpg
ADDED
inputs/Image_208.jpg
ADDED
inputs/Image_209.jpg
ADDED
inputs/Image_213.jpg
ADDED
inputs/Image_222.jpg
ADDED
inputs/Image_267.jpg
ADDED
inputs/Image_28.jpg
ADDED
inputs/Image_33.jpg
ADDED
inputs/Image_35.jpg
ADDED
inputs/Image_409.jpg
ADDED
inputs/Image_41.jpg
ADDED
inputs/Image_416.jpg
ADDED
inputs/Image_417.jpg
ADDED
inputs/Image_42.jpg
ADDED
inputs/Image_516.jpg
ADDED
inputs/Image_527.jpg
ADDED
inputs/Image_58.jpg
ADDED
inputs/Image_7.jpg
ADDED