Spaces:
Runtime error
Runtime error
Cawinchan
commited on
Commit
•
6781cf0
1
Parent(s):
c1d90be
requirements
Browse files- app.py +2 -2
- requirements.txt +1 -1
- src/app.py +0 -156
- src/main.py +0 -156
app.py
CHANGED
@@ -67,8 +67,8 @@ with st.spinner('Model is being loaded..'):
|
|
67 |
PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
|
68 |
st.write(PATH)
|
69 |
# Use cuda to enable gpu usage for pytorch
|
70 |
-
|
71 |
-
device = torch.device("cpu")
|
72 |
# if MODEL_NAME in 'efficientnet':
|
73 |
# efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
|
74 |
|
|
|
67 |
PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
|
68 |
st.write(PATH)
|
69 |
# Use cuda to enable gpu usage for pytorch
|
70 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
+
# device = torch.device("cpu")
|
72 |
# if MODEL_NAME in 'efficientnet':
|
73 |
# efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
|
74 |
|
requirements.txt
CHANGED
@@ -5,4 +5,4 @@ numpy==1.21.6
|
|
5 |
pathlib
|
6 |
pathlib2==2.3.5
|
7 |
torch
|
8 |
-
pandas
|
|
|
5 |
pathlib
|
6 |
pathlib2==2.3.5
|
7 |
torch
|
8 |
+
pandas
|
src/app.py
DELETED
@@ -1,156 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import torchvision
|
3 |
-
from torchvision import transforms, datasets
|
4 |
-
from PIL import Image, ImageOps
|
5 |
-
import numpy as np
|
6 |
-
from pathlib import Path
|
7 |
-
import torch
|
8 |
-
import torchvision.models as models
|
9 |
-
from torchvision import transforms, datasets
|
10 |
-
from torch.utils.data import DataLoader
|
11 |
-
import torch.nn as nn
|
12 |
-
from torch import optim
|
13 |
-
import pandas as pd
|
14 |
-
|
15 |
-
label_map = {
|
16 |
-
0: "ALB",
|
17 |
-
1: "BET",
|
18 |
-
2: "DOL",
|
19 |
-
3: "LAG",
|
20 |
-
4: "NoF",
|
21 |
-
5: "OTHER",
|
22 |
-
6: "SHARK",
|
23 |
-
7: "YFT",
|
24 |
-
}
|
25 |
-
|
26 |
-
label_list= [
|
27 |
-
"ALB",
|
28 |
-
"BET",
|
29 |
-
"DOL",
|
30 |
-
"LAG",
|
31 |
-
"NoF",
|
32 |
-
"OTHER",
|
33 |
-
"SHARK",
|
34 |
-
"YFT"
|
35 |
-
]
|
36 |
-
|
37 |
-
predicted_to_actual_dict = {
|
38 |
-
"ALB" : 'Albacore Tuna',
|
39 |
-
"BET" : 'Bigeye Tuna',
|
40 |
-
"DOL" : 'Dolphinfish, Mahi Mahi',
|
41 |
-
"LAG" : 'Opah, Moonfish',
|
42 |
-
"NoF" : 'No Fish',
|
43 |
-
"OTHER" : 'Fish present but not in target categories',
|
44 |
-
"SHARK" : 'Shark, including Silky & Shortfin Mako',
|
45 |
-
"YFT" : 'Yellowfin Tuna'
|
46 |
-
}
|
47 |
-
|
48 |
-
fish_to_wiki = {
|
49 |
-
0: "https://en.wikipedia.org/wiki/Albacore",
|
50 |
-
1: "https://en.wikipedia.org/wiki/Bigeye_tuna",
|
51 |
-
2: "https://en.wikipedia.org/wiki/Mahi-mahi",
|
52 |
-
3: "https://en.wikipedia.org/wiki/Opah",
|
53 |
-
4: "https://en.wikipedia.org/wiki/Fish",
|
54 |
-
5: "https://en.wikipedia.org/wiki/Fish",
|
55 |
-
6: "https://en.wikipedia.org/wiki/Shark",
|
56 |
-
7: "https://en.wikipedia.org/wiki/Yellowfin_tuna",
|
57 |
-
}
|
58 |
-
|
59 |
-
MODEL_NAME = 'efficientnet'
|
60 |
-
|
61 |
-
@st.cache()
|
62 |
-
def augment_model(efficientnet):
|
63 |
-
efficientnet.classifier[-1] = nn.Linear(in_features=1792, out_features=len(label_map), bias=True)
|
64 |
-
return efficientnet
|
65 |
-
|
66 |
-
with st.spinner('Model is being loaded..'):
|
67 |
-
PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
|
68 |
-
st.write(PATH)
|
69 |
-
# Use cuda to enable gpu usage for pytorch
|
70 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
-
device = torch.device("cpu")
|
72 |
-
# if MODEL_NAME in 'efficientnet':
|
73 |
-
# efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
|
74 |
-
|
75 |
-
# model_ft = augment_model(efficientnet)
|
76 |
-
# model_ft.load_state_dict(torch.load(PATH,map_location=device))
|
77 |
-
|
78 |
-
# else:
|
79 |
-
st.write(torch.cuda.is_available())
|
80 |
-
model_ft = torch.load(PATH,map_location=device)
|
81 |
-
|
82 |
-
st.write("""
|
83 |
-
# Endangered Fish Classification
|
84 |
-
"""
|
85 |
-
)
|
86 |
-
st.write('Nearly half of the world depends on seafood for their main source of protein. In the Western and Central Pacific, where 60% of the world’s tuna is caught, illegal, unreported, and unregulated fishing practices are threatening marine ecosystems, global seafood supplies and local livelihoods. The Nature Conservancy is working with local, regional and global partners to preserve this fishery for the future.')
|
87 |
-
st.write('Currently, the Conservancy is looking to the future by using cameras to dramatically scale the monitoring of fishing activities to fill critical science and compliance monitoring data gaps. Our trained model helps to identify when target endangered species have been caught by fishermen.')
|
88 |
-
file = st.file_uploader("Please upload your fish image", type=["jpg","png"])
|
89 |
-
|
90 |
-
|
91 |
-
st.set_option('deprecation.showfileUploaderEncoding', False)
|
92 |
-
|
93 |
-
@st.cache()
|
94 |
-
def import_and_predict(image_data: Image.Image, model, k: int, index_to_label_dict: dict)-> list:
|
95 |
-
|
96 |
-
if MODEL_NAME in 'vgg':
|
97 |
-
transform = transforms.Compose([
|
98 |
-
transforms.Resize((224, 224)),
|
99 |
-
transforms.ToTensor(),
|
100 |
-
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
|
101 |
-
)
|
102 |
-
if MODEL_NAME in 'resnet' or MODEL_NAME in 'alexnet' or MODEL_NAME in 'efficientnet':
|
103 |
-
transform = transforms.Compose([
|
104 |
-
transforms.Resize(256),
|
105 |
-
transforms.CenterCrop(224),
|
106 |
-
transforms.ToTensor(),
|
107 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]
|
108 |
-
)
|
109 |
-
|
110 |
-
actual_img = transform(image_data).to(device)
|
111 |
-
actual_img = actual_img.unsqueeze(0) # add one dimension to the front to account for batch_size
|
112 |
-
|
113 |
-
formatted_predictions = model(actual_img)
|
114 |
-
return formatted_predictions
|
115 |
-
|
116 |
-
if file is None:
|
117 |
-
pass
|
118 |
-
else:
|
119 |
-
image = Image.open(file)
|
120 |
-
|
121 |
-
st.image(image, use_column_width=True)
|
122 |
-
|
123 |
-
model_ft.eval()
|
124 |
-
predictions = import_and_predict(image, model_ft, k = 3, index_to_label_dict = label_map)
|
125 |
-
|
126 |
-
predicted_fish = label_map[int(torch.argmax(predictions))]
|
127 |
-
normalised_list = torch.nn.functional.softmax(predictions, dim = 1)
|
128 |
-
values, indices = torch.topk(normalised_list, 3)
|
129 |
-
|
130 |
-
st.title('The predicted fish is: ' + predicted_to_actual_dict[predicted_fish])
|
131 |
-
|
132 |
-
st.title('Here are the three most likely fish species(click for more info!)')
|
133 |
-
df = pd.DataFrame(data=np.zeros((3, 2)),
|
134 |
-
columns=['Species', 'Confidence Level'],
|
135 |
-
index=np.linspace(1, 3, 3, dtype=int))
|
136 |
-
|
137 |
-
# print(values.detach().numpy()[0][1])
|
138 |
-
|
139 |
-
for count, i in enumerate(values.detach().numpy()[0]):
|
140 |
-
x = int(indices.detach().numpy()[0][count])
|
141 |
-
df.iloc[count, 0] = f'<a href="{fish_to_wiki[x]}" target="_blank">{predicted_to_actual_dict[label_map[x]].title()}</a>'
|
142 |
-
df.iloc[count, 1] = np.format_float_positional(i, precision=8)
|
143 |
-
|
144 |
-
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
145 |
-
if predicted_fish not in ['OTHER', 'Nof']:
|
146 |
-
|
147 |
-
PATH_fish = Path(__file__).resolve().parent/'data'/'fishes_ref'/ (predicted_fish + '.jpg')
|
148 |
-
st.title('Here is a sample image of ' + predicted_to_actual_dict[predicted_fish])
|
149 |
-
reference_image = Image.open(PATH_fish)
|
150 |
-
st.image(reference_image)
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main.py
DELETED
@@ -1,156 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import torchvision
|
3 |
-
from torchvision import transforms, datasets
|
4 |
-
from PIL import Image, ImageOps
|
5 |
-
import numpy as np
|
6 |
-
from pathlib import Path
|
7 |
-
import torch
|
8 |
-
import torchvision.models as models
|
9 |
-
from torchvision import transforms, datasets
|
10 |
-
from torch.utils.data import DataLoader
|
11 |
-
import torch.nn as nn
|
12 |
-
from torch import optim
|
13 |
-
import pandas as pd
|
14 |
-
|
15 |
-
label_map = {
|
16 |
-
0: "ALB",
|
17 |
-
1: "BET",
|
18 |
-
2: "DOL",
|
19 |
-
3: "LAG",
|
20 |
-
4: "NoF",
|
21 |
-
5: "OTHER",
|
22 |
-
6: "SHARK",
|
23 |
-
7: "YFT",
|
24 |
-
}
|
25 |
-
|
26 |
-
label_list= [
|
27 |
-
"ALB",
|
28 |
-
"BET",
|
29 |
-
"DOL",
|
30 |
-
"LAG",
|
31 |
-
"NoF",
|
32 |
-
"OTHER",
|
33 |
-
"SHARK",
|
34 |
-
"YFT"
|
35 |
-
]
|
36 |
-
|
37 |
-
predicted_to_actual_dict = {
|
38 |
-
"ALB" : 'Albacore Tuna',
|
39 |
-
"BET" : 'Bigeye Tuna',
|
40 |
-
"DOL" : 'Dolphinfish, Mahi Mahi',
|
41 |
-
"LAG" : 'Opah, Moonfish',
|
42 |
-
"NoF" : 'No Fish',
|
43 |
-
"OTHER" : 'Fish present but not in target categories',
|
44 |
-
"SHARK" : 'Shark, including Silky & Shortfin Mako',
|
45 |
-
"YFT" : 'Yellowfin Tuna'
|
46 |
-
}
|
47 |
-
|
48 |
-
fish_to_wiki = {
|
49 |
-
0: "https://en.wikipedia.org/wiki/Albacore",
|
50 |
-
1: "https://en.wikipedia.org/wiki/Bigeye_tuna",
|
51 |
-
2: "https://en.wikipedia.org/wiki/Mahi-mahi",
|
52 |
-
3: "https://en.wikipedia.org/wiki/Opah",
|
53 |
-
4: "https://en.wikipedia.org/wiki/Fish",
|
54 |
-
5: "https://en.wikipedia.org/wiki/Fish",
|
55 |
-
6: "https://en.wikipedia.org/wiki/Shark",
|
56 |
-
7: "https://en.wikipedia.org/wiki/Yellowfin_tuna",
|
57 |
-
}
|
58 |
-
|
59 |
-
MODEL_NAME = 'efficientnet'
|
60 |
-
|
61 |
-
@st.cache()
|
62 |
-
def augment_model(efficientnet):
|
63 |
-
efficientnet.classifier[-1] = nn.Linear(in_features=1792, out_features=len(label_map), bias=True)
|
64 |
-
return efficientnet
|
65 |
-
|
66 |
-
with st.spinner('Model is being loaded..'):
|
67 |
-
PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
|
68 |
-
st.write(PATH)
|
69 |
-
# Use cuda to enable gpu usage for pytorch
|
70 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
-
device = torch.device("cpu")
|
72 |
-
# if MODEL_NAME in 'efficientnet':
|
73 |
-
# efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
|
74 |
-
|
75 |
-
# model_ft = augment_model(efficientnet)
|
76 |
-
# model_ft.load_state_dict(torch.load(PATH,map_location=device))
|
77 |
-
|
78 |
-
# else:
|
79 |
-
st.write(torch.cuda.is_available())
|
80 |
-
model_ft = torch.load(PATH,map_location=device)
|
81 |
-
|
82 |
-
st.write("""
|
83 |
-
# Endangered Fish Classification
|
84 |
-
"""
|
85 |
-
)
|
86 |
-
st.write('Nearly half of the world depends on seafood for their main source of protein. In the Western and Central Pacific, where 60% of the world’s tuna is caught, illegal, unreported, and unregulated fishing practices are threatening marine ecosystems, global seafood supplies and local livelihoods. The Nature Conservancy is working with local, regional and global partners to preserve this fishery for the future.')
|
87 |
-
st.write('Currently, the Conservancy is looking to the future by using cameras to dramatically scale the monitoring of fishing activities to fill critical science and compliance monitoring data gaps. Our trained model helps to identify when target endangered species have been caught by fishermen.')
|
88 |
-
file = st.file_uploader("Please upload your fish image", type=["jpg","png"])
|
89 |
-
|
90 |
-
|
91 |
-
st.set_option('deprecation.showfileUploaderEncoding', False)
|
92 |
-
|
93 |
-
@st.cache()
|
94 |
-
def import_and_predict(image_data: Image.Image, model, k: int, index_to_label_dict: dict)-> list:
|
95 |
-
|
96 |
-
if MODEL_NAME in 'vgg':
|
97 |
-
transform = transforms.Compose([
|
98 |
-
transforms.Resize((224, 224)),
|
99 |
-
transforms.ToTensor(),
|
100 |
-
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
|
101 |
-
)
|
102 |
-
if MODEL_NAME in 'resnet' or MODEL_NAME in 'alexnet' or MODEL_NAME in 'efficientnet':
|
103 |
-
transform = transforms.Compose([
|
104 |
-
transforms.Resize(256),
|
105 |
-
transforms.CenterCrop(224),
|
106 |
-
transforms.ToTensor(),
|
107 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]
|
108 |
-
)
|
109 |
-
|
110 |
-
actual_img = transform(image_data).to(device)
|
111 |
-
actual_img = actual_img.unsqueeze(0) # add one dimension to the front to account for batch_size
|
112 |
-
|
113 |
-
formatted_predictions = model(actual_img)
|
114 |
-
return formatted_predictions
|
115 |
-
|
116 |
-
if file is None:
|
117 |
-
pass
|
118 |
-
else:
|
119 |
-
image = Image.open(file)
|
120 |
-
|
121 |
-
st.image(image, use_column_width=True)
|
122 |
-
|
123 |
-
model_ft.eval()
|
124 |
-
predictions = import_and_predict(image, model_ft, k = 3, index_to_label_dict = label_map)
|
125 |
-
|
126 |
-
predicted_fish = label_map[int(torch.argmax(predictions))]
|
127 |
-
normalised_list = torch.nn.functional.softmax(predictions, dim = 1)
|
128 |
-
values, indices = torch.topk(normalised_list, 3)
|
129 |
-
|
130 |
-
st.title('The predicted fish is: ' + predicted_to_actual_dict[predicted_fish])
|
131 |
-
|
132 |
-
st.title('Here are the three most likely fish species(click for more info!)')
|
133 |
-
df = pd.DataFrame(data=np.zeros((3, 2)),
|
134 |
-
columns=['Species', 'Confidence Level'],
|
135 |
-
index=np.linspace(1, 3, 3, dtype=int))
|
136 |
-
|
137 |
-
# print(values.detach().numpy()[0][1])
|
138 |
-
|
139 |
-
for count, i in enumerate(values.detach().numpy()[0]):
|
140 |
-
x = int(indices.detach().numpy()[0][count])
|
141 |
-
df.iloc[count, 0] = f'<a href="{fish_to_wiki[x]}" target="_blank">{predicted_to_actual_dict[label_map[x]].title()}</a>'
|
142 |
-
df.iloc[count, 1] = np.format_float_positional(i, precision=8)
|
143 |
-
|
144 |
-
st.write(df.to_html(escape=False), unsafe_allow_html=True)
|
145 |
-
if predicted_fish not in ['OTHER', 'Nof']:
|
146 |
-
|
147 |
-
PATH_fish = Path(__file__).resolve().parent/'data'/'fishes_ref'/ (predicted_fish + '.jpg')
|
148 |
-
st.title('Here is a sample image of ' + predicted_to_actual_dict[predicted_fish])
|
149 |
-
reference_image = Image.open(PATH_fish)
|
150 |
-
st.image(reference_image)
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|