DataRaptor commited on
Commit
1556762
·
1 Parent(s): 0e7de7e

Upload 28 files

Browse files
ModelClass.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn, optim
4
+ from torchvision import transforms, models
5
+ #from torch_snippets import *
6
+ #from torch.utils.data import DataLoader, Dataset
7
+ #from torchsummary import summary
8
+
9
+ #import seaborn as sns
10
+ #import matplotlib.pyplot as plt
11
+ #from sklearn.model_selection import train_test_split
12
+ from PIL import Image
13
+ #import numpy as np
14
+ #import cv2
15
+ #from glob import glob
16
+ #import pandas as pd
17
+ import numpy as np
18
+
19
+ #device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+
21
+
22
+
23
+ class ActionClassifier(nn.Module):
24
+ def __init__(self, ntargets):
25
+ super().__init__()
26
+ resnet = models.resnet50(pretrained=True, progress=True)
27
+ modules = list(resnet.children())[:-1] # delete last layer
28
+ self.resnet = nn.Sequential(*modules)
29
+ for param in self.resnet.parameters():
30
+ param.requires_grad = False
31
+ self.fc = nn.Sequential(
32
+ nn.Flatten(),
33
+ nn.BatchNorm1d(resnet.fc.in_features),
34
+ nn.Dropout(0.2),
35
+ nn.Linear(resnet.fc.in_features, 256),
36
+ nn.ReLU(),
37
+ nn.BatchNorm1d(256),
38
+ nn.Dropout(0.2),
39
+ nn.Linear(256, ntargets)
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.resnet(x)
44
+ x = self.fc(x)
45
+ return x
46
+
47
+
48
+
49
+ def get_transform():
50
+ transform = transforms.Compose([
51
+ transforms.Resize([224, 244]),
52
+ transforms.ToTensor(),
53
+ # std multiply by 255 to convert img of [0, 255]
54
+ # to img of [0, 1]
55
+ transforms.Normalize((0.485, 0.456, 0.406),
56
+ (0.229*255, 0.224*255, 0.225*255))]
57
+ )
58
+ return transform
59
+
60
+
61
+ def get_model():
62
+ model = ActionClassifier(15)
63
+ model.load_state_dict(torch.load('./classifier_weights.pth', map_location=torch.device('cpu')))
64
+ return model
65
+
66
+
67
+ def get_class(index):
68
+ ind2cat = [
69
+ 'calling',
70
+ 'clapping',
71
+ 'cycling',
72
+ 'dancing',
73
+ 'drinking',
74
+ 'eating',
75
+ 'fighting',
76
+ 'hugging',
77
+ 'laughing',
78
+ 'listening_to_music',
79
+ 'running',
80
+ 'sitting',
81
+ 'sleeping',
82
+ 'texting',
83
+ 'using_laptop'
84
+ ]
85
+ return ind2cat[index]
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+ # img = Image.open('./inputs/Image_102.jpg').convert('RGB')
95
+
96
+ # #print(transform(img))
97
+
98
+ # img = transform(img)
99
+
100
+ # img = img.unsqueeze(dim=0)
101
+ # print(img.shape)
102
+
103
+
104
+
105
+
106
+
107
+
108
+ # model.eval()
109
+ # with torch.no_grad():
110
+ # out = model(img)
111
+ # out = nn.Softmax()(out).squeeze()
112
+ # print(out.shape)
113
+ # res = torch.argmax(out)
114
+
115
+ # print(ind2cat[res])
116
+
117
+
118
+
119
+
app.py CHANGED
@@ -2,6 +2,58 @@ import streamlit as st
2
  import numpy as np
3
  from PIL import Image
4
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  # fix sidebar
@@ -19,6 +71,10 @@ st.markdown("""
19
  margin: 0.125rem 2rem;
20
  background-color: rgb(181 197 227 / 18%) !important;
21
  }
 
 
 
 
22
  </style>
23
  """, unsafe_allow_html=True
24
  )
@@ -29,12 +85,10 @@ hide_st_style = """
29
  header {visibility: hidden;}
30
  </style>
31
  """
32
- st.markdown(hide_st_style, unsafe_allow_html=True)
33
 
34
 
35
 
36
-
37
- # Function to load and predict image
38
  def predict(image):
39
  # Dummy prediction
40
  classes = ['cat', 'dog']
@@ -42,53 +96,52 @@ def predict(image):
42
  prediction /= np.sum(prediction)
43
  return dict(zip(classes, prediction))
44
 
45
- # Define app layout
46
- #st.set_page_config(page_title='Image Classification App', page_icon=':camera:', layout='wide')
47
- st.title('HappyWhale')
48
- st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/<username>/<project_name>?workspace=user-<username>)")
49
-
50
- st.markdown('This project aims to identify whales and dolphins by their unique characteristics. It can help researchers understand their behavior, population dynamics, and migration patterns. This project can aid researchers in identifying these marine mammals, providing valuable data for conservation efforts. [[Source Code]](https://kaggle.com/)')
51
-
52
-
53
- # Add file uploader
54
- uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
55
-
56
- # Add test image selector
57
- test_images = {
58
- 'Cat': 'https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg',
59
- 'Dog': 'https://upload.wikimedia.org/wikipedia/commons/thumb/6/6e/Golde33443.jpg/1200px-Golde33443.jpg',
60
- 'Bird': 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/0a/Scarlet_Tanager_-_male_%28cropped%29.jpg/1200px-Scarlet_Tanager_-_male_%28cropped%29.jpg'
61
- }
62
- test_image = st.selectbox('Or choose a test image', list(test_images.keys()))
63
-
64
- st.subheader('Selected Image')
65
- # Define layout of app
66
- left_column, right_column = st.columns([1, 2.5], gap="medium")
67
- with left_column:
68
 
69
- if uploaded_file is not None:
70
- image = Image.open(uploaded_file)
71
- st.image(image, use_column_width=True)
72
- else:
73
- image_url = test_images[test_image]
74
- image = Image.open(requests.get(image_url, stream=True).raw)
75
- st.image(image, use_column_width=True)
76
-
 
77
 
78
- if st.button('✨ Get prediction from AI', type='primary'):
79
- spacer = st.empty()
 
 
 
80
 
 
 
 
 
 
 
 
 
81
 
82
- prediction = predict(image)
83
- right_column.subheader('Results')
84
- for class_name, class_probability in prediction.items():
85
- right_column.write(f'{class_name}: {class_probability:.2%}')
86
- right_column.progress(class_probability)
87
-
88
-
89
- # Display a footer with links and credits
90
- st.markdown("---")
91
- st.markdown("Built by [Shamim Ahamed](https://your-portfolio-website.com/). Data provided by [Kaggle](https://www.kaggle.com/c/)")
92
- #st.markdown("Data provided by [The Feedback Prize - ELLIPSE Corpus Scoring Challenge on Kaggle](https://www.kaggle.com/c/feedbackprize-ellipse-corpus-scoring-challenge)")
 
 
 
 
 
 
 
 
93
 
94
-
 
 
2
  import numpy as np
3
  from PIL import Image
4
  import requests
5
+ import ModelClass
6
+ from glob import glob
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ @st.cache_resource
11
+ def load_model():
12
+ return ModelClass.get_model()
13
+
14
+ @st.cache_data
15
+ def get_images():
16
+ l = glob('./inputs/*')
17
+ l = {i.split('/')[-1]: i for i in l}
18
+ return l
19
+
20
+
21
+ def infer(img):
22
+ image = img.convert('RGB')
23
+ image = ModelClass.get_transform()(image)
24
+ image = image.unsqueeze(dim=0)
25
+
26
+ model = load_model()
27
+ model.eval()
28
+ with torch.no_grad():
29
+ out = model(image)
30
+ out = nn.Softmax()(out).squeeze()
31
+ return out
32
+
33
+
34
+
35
+
36
+ st.set_page_config(
37
+ page_title="Whale Identification",
38
+ page_icon="🧊",
39
+ layout="centered",
40
+ initial_sidebar_state="expanded",
41
+ menu_items={
42
+ 'Get Help': 'https://www.extremelycoolapp.com/help',
43
+ 'Report a bug': "https://www.extremelycoolapp.com/bug",
44
+ 'About': """
45
+ # This is a header. This is an *extremely* cool app!
46
+ How how are you doin.
47
+
48
+ ---
49
+ I am fine
50
+
51
+
52
+ <style>
53
+ </style>
54
+ """
55
+ }
56
+ )
57
 
58
 
59
  # fix sidebar
 
71
  margin: 0.125rem 2rem;
72
  background-color: rgb(181 197 227 / 18%) !important;
73
  }
74
+ .css-1y4p8pa {
75
+ padding: 3rem 1rem 10rem;
76
+ max-width: 58rem;
77
+ }
78
  </style>
79
  """, unsafe_allow_html=True
80
  )
 
85
  header {visibility: hidden;}
86
  </style>
87
  """
88
+ #st.markdown(hide_st_style, unsafe_allow_html=True)
89
 
90
 
91
 
 
 
92
  def predict(image):
93
  # Dummy prediction
94
  classes = ['cat', 'dog']
 
96
  prediction /= np.sum(prediction)
97
  return dict(zip(classes, prediction))
98
 
99
+ def app():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ st.title('ActionNet')
102
+ st.markdown("[![View in W&B](https://img.shields.io/badge/View%20in-W%26B-blue)](https://wandb.ai/<username>/<project_name>?workspace=user-<username>)")
103
+ st.markdown('This project aims to identify whales and dolphins by their unique characteristics. It can help researchers understand their behavior, population dynamics, and migration patterns. This project can aid researchers in identifying these marine mammals, providing valuable data for conservation efforts. [[Source Code]](https://kaggle.com/)')
104
+
105
+
106
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
107
+
108
+ test_images = get_images()
109
+ test_image = st.selectbox('Or choose a test image', list(test_images.keys()))
110
 
111
+
112
+ st.subheader('Selected Image')
113
+
114
+ left_column, right_column = st.columns([1.5, 2.5], gap="medium")
115
+ with left_column:
116
 
117
+ if uploaded_file is not None:
118
+ image = Image.open(uploaded_file)
119
+ st.image(image, use_column_width=True)
120
+ else:
121
+ image_url = test_images[test_image]
122
+ image = Image.open(image_url)
123
+ st.image(image, use_column_width=True)
124
+
125
 
126
+ if st.button('✨ Get prediction from AI', type='primary'):
127
+ spacer = st.empty()
128
+
129
+
130
+ res = infer(image)
131
+ res = torch.argmax(res)
132
+ cname = ModelClass.get_class(res)
133
+ st.write(f'{cname}')
134
+
135
+
136
+ prediction = predict(image)
137
+ right_column.subheader('Results')
138
+ for class_name, class_probability in prediction.items():
139
+ right_column.write(f'{class_name}: {class_probability:.2%}')
140
+ right_column.progress(class_probability)
141
+
142
+
143
+ st.markdown("---")
144
+ st.markdown("Built by [Shamim Ahamed](https://your-portfolio-website.com/). Data provided by [Kaggle](https://www.kaggle.com/c/)")
145
 
146
+
147
+ app()
classifier_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac67f64705d4109573e06c65925367ee1697f4359ebf1a227a20ba69d3e7a1ba
3
+ size 96504105
inputs/Image_102.jpg ADDED
inputs/Image_129.jpg ADDED
inputs/Image_131.jpg ADDED
inputs/Image_18.jpg ADDED
inputs/Image_180.jpg ADDED
inputs/Image_183.jpg ADDED
inputs/Image_202.jpg ADDED
inputs/Image_207.jpg ADDED
inputs/Image_208.jpg ADDED
inputs/Image_209.jpg ADDED
inputs/Image_213.jpg ADDED
inputs/Image_222.jpg ADDED
inputs/Image_267.jpg ADDED
inputs/Image_28.jpg ADDED
inputs/Image_33.jpg ADDED
inputs/Image_35.jpg ADDED
inputs/Image_409.jpg ADDED
inputs/Image_41.jpg ADDED
inputs/Image_416.jpg ADDED
inputs/Image_417.jpg ADDED
inputs/Image_42.jpg ADDED
inputs/Image_516.jpg ADDED
inputs/Image_527.jpg ADDED
inputs/Image_58.jpg ADDED
inputs/Image_7.jpg ADDED