MikeTrizna commited on
Commit
4c907a8
โ€ข
1 Parent(s): b201d17

Initial commit of app. Directly copied from miketrizna/amazonian_fish_classifier

Browse files
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [client]
2
+ showErrorDetails = false
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Amazonian Fish Classifier
3
- emoji: ๐Ÿ‘€
4
  colorFrom: green
5
  colorTo: pink
6
  sdk: streamlit
@@ -10,4 +10,8 @@ pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Amazonian Fish Classifier
3
+ emoji: ๐Ÿ 
4
  colorFrom: green
5
  colorTo: pink
6
  sdk: streamlit
 
10
  license: mit
11
  ---
12
 
13
+ This is a demonstration app of the two machine learning models described in the paper:
14
+
15
+ > Robillard, A., Trizna, M. G., Ruiz-Tafur, K., Panduro, E. D., de Santana, C. D., White, A. E., Dikow, R. B., Deichmann, J. 2023. Application of a Deep Learning Image Classifier for Identification of Amazonian Fishes. *Ecology and Evolution* [https://doi.org/10.1002/ece3.9987](https://doi.org/10.1002/ece3.9987)
16
+
17
+ The models weights and image data are available on FigShare at [https://doi.org/10.25573/data.c.5761097.v1](https://doi.org/10.25573/data.c.5761097.v1)
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import urllib
4
+ import fastai.vision.all as fai_vision
5
+ import numpy as np
6
+ from pathlib import Path
7
+ import pathlib
8
+ from PIL import Image
9
+ import platform
10
+ import altair as alt
11
+ import pandas as pd
12
+ import frontmatter
13
+
14
+ def main():
15
+ st.title('Fish Masker and Classifier')
16
+
17
+ with open('README.md') as readme_file:
18
+ readme = frontmatter.load(readme_file)
19
+ st.markdown(readme.content)
20
+
21
+ data_loader, segmenter = load_unet_model()
22
+ classification_model = load_classification_model()
23
+
24
+ st.markdown("## Instructions")
25
+ st.markdown("Upload an Amazonian fish photo for masking.")
26
+ uploaded_image = st.file_uploader("", IMAGE_TYPES)
27
+ if uploaded_image:
28
+ image_data = uploaded_image.read()
29
+ st.markdown('## Original image')
30
+ st.image(image_data, use_column_width=True)
31
+
32
+ original_pil = Image.open(uploaded_image)
33
+
34
+ original_pil.save('original.jpg')
35
+
36
+ single_file = [Path('original.jpg')]
37
+ single_pil = Image.open(single_file[0])
38
+ input_dl = segmenter.dls.test_dl(single_file)
39
+ masks, _ = segmenter.get_preds(dl=input_dl)
40
+ masked_pil, percentage_fish = mask_fish_pil(single_pil, masks[0])
41
+
42
+ st.markdown('## Masked image')
43
+ st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
44
+ st.image(masked_pil, use_column_width=True)
45
+
46
+ masked_pil.save('masked.jpg')
47
+
48
+ st.markdown('## Classification')
49
+
50
+ prediction = classification_model.predict('masked.jpg')
51
+ pred_chart = predictions_to_chart(prediction, classes = classification_model.dls.vocab)
52
+ st.altair_chart(pred_chart, use_container_width=True)
53
+
54
+
55
+ def mask_fish_pil(unmasked_fish, fastai_mask):
56
+ unmasked_np = np.array(unmasked_fish)
57
+ np_mask = fastai_mask.argmax(dim=0).numpy()
58
+ total_pixels = np_mask.size
59
+ fish_pixels = np.count_nonzero(np_mask)
60
+ percentage_fish = (fish_pixels / total_pixels) * 100
61
+ np_mask = (255 / np_mask.max() * (np_mask - np_mask.min())).astype(np.uint8)
62
+ np_mask = np.array(Image.fromarray(np_mask).resize(unmasked_np.shape[1::-1], Image.BILINEAR))
63
+ np_mask = np_mask.reshape(*np_mask.shape, 1) / 255
64
+ masked_fish_np = (unmasked_np * np_mask).astype(np.uint8)
65
+ masked_fish_pil = Image.fromarray(masked_fish_np)
66
+ return masked_fish_pil, percentage_fish
67
+
68
+ def predictions_to_chart(prediction, classes):
69
+ pred_rows = []
70
+ for i, conf in enumerate(list(prediction[2])):
71
+ pred_row = {'class': classes[i],
72
+ 'probability': round(float(conf) * 100,2)}
73
+ pred_rows.append(pred_row)
74
+ pred_df = pd.DataFrame(pred_rows)
75
+ pred_df.head()
76
+ top_probs = pred_df.sort_values('probability', ascending=False).head(4)
77
+ chart = (
78
+ alt.Chart(top_probs)
79
+ .mark_bar()
80
+ .encode(
81
+ x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
82
+ y=alt.Y("class:N",
83
+ sort=alt.EncodingSortField(field="probability", order="descending"))
84
+ )
85
+ )
86
+ return chart
87
+
88
+ @st.cache(allow_output_mutation=True)
89
+ def load_unet_model():
90
+ data_loader = fai_vision.SegmentationDataLoaders.from_label_func(
91
+ path = Path("."),
92
+ bs = 1,
93
+ fnames = [Path('test_fish.jpg')],
94
+ label_func = lambda x: x,
95
+ codes = np.array(["Photo", "Masks"], dtype=str),
96
+ item_tfms = [fai_vision.Resize(256, method = 'squish'),],
97
+ batch_tfms = [fai_vision.IntToFloatTensor(div_mask = 255)],
98
+ valid_pct = 0.2, num_workers = 0)
99
+ segmenter = fai_vision.unet_learner(data_loader, fai_vision.resnet34)
100
+ segmenter.load('fish_mask_model')
101
+ return data_loader, segmenter
102
+
103
+ @st.cache(allow_output_mutation=True)
104
+ def load_classification_model():
105
+ plt = platform.system()
106
+
107
+ if plt == 'Linux' or plt == 'Darwin':
108
+ pathlib.WindowsPath = pathlib.PosixPath
109
+ inf_model = fai_vision.load_learner('models/fish_classification_model.pkl', cpu=True)
110
+
111
+ return inf_model
112
+
113
+ IMAGE_TYPES = ["png", "jpg","jpeg"]
114
+
115
+ if __name__ == "__main__":
116
+ main()
models/fish_classification_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ac16550590dd60da201ce13e2f1b057d5343ef490db8663c463f8bbefef610e
3
+ size 179319095
models/fish_mask_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29b8afc516eb9f19e99dc53e924839a7157ac241d13f0945aec4717574c7908a
3
+ size 494929527
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit==0.89
2
+ fastai==2.2
3
+ protobuf==3.20
4
+ altair
5
+ pandas
6
+ frontmatter
test_fish.jpg ADDED