Cawinchan committed
Commit 6781cf0
1 Parent(s): c1d90be

requirements

Files changed (4)
  1. app.py +2 -2
  2. requirements.txt +1 -1
  3. src/app.py +0 -156
  4. src/main.py +0 -156
app.py CHANGED
@@ -67,8 +67,8 @@ with st.spinner('Model is being loaded..'):
     PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
     st.write(PATH)
     # Use cuda to enable gpu usage for pytorch
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    device = torch.device("cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # device = torch.device("cpu")
     # if MODEL_NAME in 'efficientnet':
     # efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
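The change above swaps a hard-coded CPU device for CUDA-when-available, so the Space can use a GPU when one is attached and still fall back to CPU otherwise. A minimal sketch of the pattern, for illustration only (MODEL_PATH is a placeholder name, not from the repo): choose the device once, pass it to torch.load as map_location so the checkpoint's tensors are remapped onto it, and keep inputs on the same device, as the app already does via .to(device) inside import_and_predict.

import torch
from pathlib import Path

# Placeholder path; the app resolves models/efficientnet_10_25_full.pt relative to the file.
MODEL_PATH = Path("models") / "efficientnet_10_25_full.pt"

# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps tensors saved on another device (e.g. the training GPU)
# onto the chosen device, so the same checkpoint also loads on CPU-only hosts.
model = torch.load(MODEL_PATH, map_location=device)
model.eval()

# Inference inputs must live on the same device as the model's parameters.
# batch = transform(image).unsqueeze(0).to(device)
# with torch.no_grad():
#     logits = model(batch)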
requirements.txt CHANGED
@@ -5,4 +5,4 @@ numpy==1.21.6
 pathlib
 pathlib2==2.3.5
 torch
-pandas==0.25.1
+pandas
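Dropping the pandas==0.25.1 pin reads as a compatibility fix: pandas 0.25.x is a 2019 release that predates numpy 1.21 and ships no wheels for recent Python versions, so pinning it next to numpy==1.21.6 will typically fail to resolve or build on the Space's runtime. Leaving pandas unpinned lets pip pick a release that matches the other requirements. If the resulting version drift is a concern, a small optional check (not part of this repo) can surface the resolved versions inside the running app:

# Optional sketch: show which versions pip actually installed, so an unpinned
# dependency drifting between deployments is easy to spot from the Space itself.
from importlib.metadata import version  # Python 3.8+

import streamlit as st

st.caption(f"pandas {version('pandas')} · numpy {version('numpy')} · torch {version('torch')}")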
src/app.py DELETED
@@ -1,156 +0,0 @@
-import streamlit as st
-import torchvision
-from torchvision import transforms, datasets
-from PIL import Image, ImageOps
-import numpy as np
-from pathlib import Path
-import torch
-import torchvision.models as models
-from torchvision import transforms, datasets
-from torch.utils.data import DataLoader
-import torch.nn as nn
-from torch import optim
-import pandas as pd
-
-label_map = {
-    0: "ALB",
-    1: "BET",
-    2: "DOL",
-    3: "LAG",
-    4: "NoF",
-    5: "OTHER",
-    6: "SHARK",
-    7: "YFT",
-}
-
-label_list= [
-    "ALB",
-    "BET",
-    "DOL",
-    "LAG",
-    "NoF",
-    "OTHER",
-    "SHARK",
-    "YFT"
-]
-
-predicted_to_actual_dict = {
-    "ALB" : 'Albacore Tuna',
-    "BET" : 'Bigeye Tuna',
-    "DOL" : 'Dolphinfish, Mahi Mahi',
-    "LAG" : 'Opah, Moonfish',
-    "NoF" : 'No Fish',
-    "OTHER" : 'Fish present but not in target categories',
-    "SHARK" : 'Shark, including Silky & Shortfin Mako',
-    "YFT" : 'Yellowfin Tuna'
-}
-
-fish_to_wiki = {
-    0: "https://en.wikipedia.org/wiki/Albacore",
-    1: "https://en.wikipedia.org/wiki/Bigeye_tuna",
-    2: "https://en.wikipedia.org/wiki/Mahi-mahi",
-    3: "https://en.wikipedia.org/wiki/Opah",
-    4: "https://en.wikipedia.org/wiki/Fish",
-    5: "https://en.wikipedia.org/wiki/Fish",
-    6: "https://en.wikipedia.org/wiki/Shark",
-    7: "https://en.wikipedia.org/wiki/Yellowfin_tuna",
-}
-
-MODEL_NAME = 'efficientnet'
-
-@st.cache()
-def augment_model(efficientnet):
-    efficientnet.classifier[-1] = nn.Linear(in_features=1792, out_features=len(label_map), bias=True)
-    return efficientnet
-
-with st.spinner('Model is being loaded..'):
-    PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
-    st.write(PATH)
-    # Use cuda to enable gpu usage for pytorch
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    device = torch.device("cpu")
-    # if MODEL_NAME in 'efficientnet':
-    # efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
-
-    # model_ft = augment_model(efficientnet)
-    # model_ft.load_state_dict(torch.load(PATH,map_location=device))
-
-    # else:
-    st.write(torch.cuda.is_available())
-    model_ft = torch.load(PATH,map_location=device)
-
-st.write("""
-# Endangered Fish Classification
-"""
-)
-st.write('Nearly half of the world depends on seafood for their main source of protein. In the Western and Central Pacific, where 60% of the world’s tuna is caught, illegal, unreported, and unregulated fishing practices are threatening marine ecosystems, global seafood supplies and local livelihoods. The Nature Conservancy is working with local, regional and global partners to preserve this fishery for the future.')
-st.write('Currently, the Conservancy is looking to the future by using cameras to dramatically scale the monitoring of fishing activities to fill critical science and compliance monitoring data gaps. Our trained model helps to identify when target endangered species have been caught by fishermen.')
-file = st.file_uploader("Please upload your fish image", type=["jpg","png"])
-
-
-st.set_option('deprecation.showfileUploaderEncoding', False)
-
-@st.cache()
-def import_and_predict(image_data: Image.Image, model, k: int, index_to_label_dict: dict)-> list:
-
-    if MODEL_NAME in 'vgg':
-        transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor(),
-            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
-        )
-    if MODEL_NAME in 'resnet' or MODEL_NAME in 'alexnet' or MODEL_NAME in 'efficientnet':
-        transform = transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]
-        )
-
-    actual_img = transform(image_data).to(device)
-    actual_img = actual_img.unsqueeze(0) # add one dimension to the front to account for batch_size
-
-    formatted_predictions = model(actual_img)
-    return formatted_predictions
-
-if file is None:
-    pass
-else:
-    image = Image.open(file)
-
-    st.image(image, use_column_width=True)
-
-    model_ft.eval()
-    predictions = import_and_predict(image, model_ft, k = 3, index_to_label_dict = label_map)
-
-    predicted_fish = label_map[int(torch.argmax(predictions))]
-    normalised_list = torch.nn.functional.softmax(predictions, dim = 1)
-    values, indices = torch.topk(normalised_list, 3)
-
-    st.title('The predicted fish is: ' + predicted_to_actual_dict[predicted_fish])
-
-    st.title('Here are the three most likely fish species(click for more info!)')
-    df = pd.DataFrame(data=np.zeros((3, 2)),
-                      columns=['Species', 'Confidence Level'],
-                      index=np.linspace(1, 3, 3, dtype=int))
-
-    # print(values.detach().numpy()[0][1])
-
-    for count, i in enumerate(values.detach().numpy()[0]):
-        x = int(indices.detach().numpy()[0][count])
-        df.iloc[count, 0] = f'<a href="{fish_to_wiki[x]}" target="_blank">{predicted_to_actual_dict[label_map[x]].title()}</a>'
-        df.iloc[count, 1] = np.format_float_positional(i, precision=8)
-
-    st.write(df.to_html(escape=False), unsafe_allow_html=True)
-    if predicted_fish not in ['OTHER', 'Nof']:
-
-        PATH_fish = Path(__file__).resolve().parent/'data'/'fishes_ref'/ (predicted_fish + '.jpg')
-        st.title('Here is a sample image of ' + predicted_to_actual_dict[predicted_fish])
-        reference_image = Image.open(PATH_fish)
-        st.image(reference_image)
-
-
-
-
-
-
src/main.py DELETED
@@ -1,156 +0,0 @@
-import streamlit as st
-import torchvision
-from torchvision import transforms, datasets
-from PIL import Image, ImageOps
-import numpy as np
-from pathlib import Path
-import torch
-import torchvision.models as models
-from torchvision import transforms, datasets
-from torch.utils.data import DataLoader
-import torch.nn as nn
-from torch import optim
-import pandas as pd
-
-label_map = {
-    0: "ALB",
-    1: "BET",
-    2: "DOL",
-    3: "LAG",
-    4: "NoF",
-    5: "OTHER",
-    6: "SHARK",
-    7: "YFT",
-}
-
-label_list= [
-    "ALB",
-    "BET",
-    "DOL",
-    "LAG",
-    "NoF",
-    "OTHER",
-    "SHARK",
-    "YFT"
-]
-
-predicted_to_actual_dict = {
-    "ALB" : 'Albacore Tuna',
-    "BET" : 'Bigeye Tuna',
-    "DOL" : 'Dolphinfish, Mahi Mahi',
-    "LAG" : 'Opah, Moonfish',
-    "NoF" : 'No Fish',
-    "OTHER" : 'Fish present but not in target categories',
-    "SHARK" : 'Shark, including Silky & Shortfin Mako',
-    "YFT" : 'Yellowfin Tuna'
-}
-
-fish_to_wiki = {
-    0: "https://en.wikipedia.org/wiki/Albacore",
-    1: "https://en.wikipedia.org/wiki/Bigeye_tuna",
-    2: "https://en.wikipedia.org/wiki/Mahi-mahi",
-    3: "https://en.wikipedia.org/wiki/Opah",
-    4: "https://en.wikipedia.org/wiki/Fish",
-    5: "https://en.wikipedia.org/wiki/Fish",
-    6: "https://en.wikipedia.org/wiki/Shark",
-    7: "https://en.wikipedia.org/wiki/Yellowfin_tuna",
-}
-
-MODEL_NAME = 'efficientnet'
-
-@st.cache()
-def augment_model(efficientnet):
-    efficientnet.classifier[-1] = nn.Linear(in_features=1792, out_features=len(label_map), bias=True)
-    return efficientnet
-
-with st.spinner('Model is being loaded..'):
-    PATH = Path(__file__).resolve().parent.parent/'models'/'efficientnet_10_25_full.pt'
-    st.write(PATH)
-    # Use cuda to enable gpu usage for pytorch
-    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    device = torch.device("cpu")
-    # if MODEL_NAME in 'efficientnet':
-    # efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b4', pretrained=True)
-
-    # model_ft = augment_model(efficientnet)
-    # model_ft.load_state_dict(torch.load(PATH,map_location=device))
-
-    # else:
-    st.write(torch.cuda.is_available())
-    model_ft = torch.load(PATH,map_location=device)
-
-st.write("""
-# Endangered Fish Classification
-"""
-)
-st.write('Nearly half of the world depends on seafood for their main source of protein. In the Western and Central Pacific, where 60% of the world’s tuna is caught, illegal, unreported, and unregulated fishing practices are threatening marine ecosystems, global seafood supplies and local livelihoods. The Nature Conservancy is working with local, regional and global partners to preserve this fishery for the future.')
-st.write('Currently, the Conservancy is looking to the future by using cameras to dramatically scale the monitoring of fishing activities to fill critical science and compliance monitoring data gaps. Our trained model helps to identify when target endangered species have been caught by fishermen.')
-file = st.file_uploader("Please upload your fish image", type=["jpg","png"])
-
-
-st.set_option('deprecation.showfileUploaderEncoding', False)
-
-@st.cache()
-def import_and_predict(image_data: Image.Image, model, k: int, index_to_label_dict: dict)-> list:
-
-    if MODEL_NAME in 'vgg':
-        transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor(),
-            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
-        )
-    if MODEL_NAME in 'resnet' or MODEL_NAME in 'alexnet' or MODEL_NAME in 'efficientnet':
-        transform = transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),]
-        )
-
-    actual_img = transform(image_data).to(device)
-    actual_img = actual_img.unsqueeze(0) # add one dimension to the front to account for batch_size
-
-    formatted_predictions = model(actual_img)
-    return formatted_predictions
-
-if file is None:
-    pass
-else:
-    image = Image.open(file)
-
-    st.image(image, use_column_width=True)
-
-    model_ft.eval()
-    predictions = import_and_predict(image, model_ft, k = 3, index_to_label_dict = label_map)
-
-    predicted_fish = label_map[int(torch.argmax(predictions))]
-    normalised_list = torch.nn.functional.softmax(predictions, dim = 1)
-    values, indices = torch.topk(normalised_list, 3)
-
-    st.title('The predicted fish is: ' + predicted_to_actual_dict[predicted_fish])
-
-    st.title('Here are the three most likely fish species(click for more info!)')
-    df = pd.DataFrame(data=np.zeros((3, 2)),
-                      columns=['Species', 'Confidence Level'],
-                      index=np.linspace(1, 3, 3, dtype=int))
-
-    # print(values.detach().numpy()[0][1])
-
-    for count, i in enumerate(values.detach().numpy()[0]):
-        x = int(indices.detach().numpy()[0][count])
-        df.iloc[count, 0] = f'<a href="{fish_to_wiki[x]}" target="_blank">{predicted_to_actual_dict[label_map[x]].title()}</a>'
-        df.iloc[count, 1] = np.format_float_positional(i, precision=8)
-
-    st.write(df.to_html(escape=False), unsafe_allow_html=True)
-    if predicted_fish not in ['OTHER', 'Nof']:
-
-        PATH_fish = Path(__file__).resolve().parent/'data'/'fishes_ref'/ (predicted_fish + '.jpg')
-        st.title('Here is a sample image of ' + predicted_to_actual_dict[predicted_fish])
-        reference_image = Image.open(PATH_fish)
-        st.image(reference_image)
-
-
-
-
-
-