File size: 3,885 Bytes
c923f4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import streamlit as st
import os
import urllib
import fastai.vision.all as fai_vision
import numpy as np
from pathlib import Path
import pathlib
from PIL import Image
import platform
import altair as alt
import pandas as pd

def main():
    """Render the Streamlit app: upload a fish photo, mask it, classify it.

    Flow: the uploaded image is written to a scratch JPEG, segmented by the
    U-Net from ``load_unet_model``, its non-fish pixels are darkened with
    ``mask_fish_pil``, and the masked image is classified by the learner from
    ``load_classification_model``; the top predictions are drawn as a bar
    chart via ``predictions_to_chart``.
    """
    import tempfile  # local import: only needed for per-run scratch files

    st.title('Fish Masker and Classifier')

    _, segmenter = load_unet_model()
    classification_model = load_classification_model()

    st.markdown("Upload an Amazonian fish photo for masking.")
    uploaded_image = st.file_uploader("", IMAGE_TYPES)
    if not uploaded_image:
        return

    st.markdown('## Original image')
    st.image(uploaded_image.read(), use_column_width=True)
    uploaded_image.seek(0)  # rewind after read() so PIL sees the whole file

    # JPEG cannot store alpha/palette modes (RGBA PNGs would fail to save), and
    # mask_fish_pil multiplies by an (H, W, 1) mask, which only broadcasts
    # correctly against a 3-channel image — so normalise every upload to RGB.
    original_pil = Image.open(uploaded_image).convert('RGB')

    # NOTE(review): Streamlit may run several browser sessions concurrently in
    # one process, so fixed file names in the working directory could be
    # overwritten by another user mid-request; use a private directory per run.
    with tempfile.TemporaryDirectory() as scratch_dir:
        original_path = Path(scratch_dir) / 'original.jpg'
        masked_path = Path(scratch_dir) / 'masked.jpg'
        original_pil.save(original_path)

        input_dl = segmenter.dls.test_dl([original_path])
        masks, _ = segmenter.get_preds(dl=input_dl)
        # Mask the JPEG-round-tripped image the segmenter actually saw; the
        # context manager closes the file handle once the pixels are copied.
        with Image.open(original_path) as saved_pil:
            masked_pil, percentage_fish = mask_fish_pil(saved_pil, masks[0])

        st.markdown('## Masked image')
        st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
        st.image(masked_pil, use_column_width=True)

        masked_pil.save(masked_path)

        st.markdown('## Classification')

        prediction = classification_model.predict(str(masked_path))
        pred_chart = predictions_to_chart(prediction, classes=classification_model.dls.vocab)
        st.altair_chart(pred_chart, use_container_width=True)


def mask_fish_pil(unmasked_fish, fastai_mask):
    """Darken every non-fish pixel of an image using a segmentation prediction.

    Args:
        unmasked_fish: PIL image to mask. NOTE(review): expected to be
            multi-channel (e.g. RGB); a single-channel image would broadcast
            incorrectly against the (H, W, 1) weight array below.
        fastai_mask: prediction for one image whose dim 0 is the class axis
            (presumably per-class scores shaped (classes, H, W) — confirm);
            ``argmax(dim=0)`` picks one class per pixel. Class 0 is treated as
            background and every other class as fish.

    Returns:
        ``(masked_pil, percentage_fish)``: a PIL image the size of
        ``unmasked_fish`` with background pixels zeroed, and the share of
        mask pixels labeled fish, in percent (0-100).
    """
    unmasked_np = np.array(unmasked_fish)
    class_mask = fastai_mask.argmax(dim=0).numpy()

    fish_pixels = np.count_nonzero(class_mask)
    percentage_fish = (fish_pixels / class_mask.size) * 100

    # Binary 0/255 mask: any non-background class counts as fish, matching the
    # count_nonzero percentage above. A min/max rescale would divide by zero
    # when no fish is present and blank the image when every pixel is fish.
    binary_mask = np.where(class_mask != 0, 255, 0).astype(np.uint8)

    # PIL's resize takes (width, height); shape[1::-1] is (W, H). Bilinear
    # resampling interpolates values across the mask boundary.
    resized_mask = Image.fromarray(binary_mask).resize(unmasked_np.shape[1::-1], Image.BILINEAR)
    # Trailing axis lets the (H, W, 1) weights broadcast over colour channels.
    weights = np.asarray(resized_mask, dtype=np.float64)[..., np.newaxis] / 255

    masked_fish_np = (unmasked_np * weights).astype(np.uint8)
    masked_fish_pil = Image.fromarray(masked_fish_np)
    return masked_fish_pil, percentage_fish

def predictions_to_chart(prediction, classes, top_k=4):
    """Build an Altair bar chart of the most probable classes.

    Args:
        prediction: result of the classifier's ``predict`` call; only
            ``prediction[2]`` is read. It is presumably the per-class
            probabilities in [0, 1], in the same order as ``classes`` — confirm.
        classes: class labels, indexed in the same order as the probabilities.
        top_k: number of highest-probability classes to plot (default 4).

    Returns:
        An ``alt.Chart`` of bars with probability in percent (x axis fixed to
        0-100, rounded to 2 decimals) against class name, sorted descending.
    """
    pred_df = pd.DataFrame(
        [
            {'class': classes[i], 'probability': round(float(conf) * 100, 2)}
            for i, conf in enumerate(prediction[2])
        ],
        # Explicit columns keep sort_values working on an empty prediction.
        columns=['class', 'probability'],
    )
    top_probs = pred_df.sort_values('probability', ascending=False).head(top_k)
    chart = (
        alt.Chart(top_probs)
        .mark_bar()
        .encode(
            x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
            y=alt.Y("class:N",
                    sort=alt.EncodingSortField(field="probability", order="descending"))
        )
    )
    return chart

@st.cache(allow_output_mutation=True)
def load_unet_model():
    """Build the resnet34 U-Net segmenter and load its trained weights.

    Cached by Streamlit so the model is constructed once rather than on every
    rerun. ``unet_learner`` needs a ``DataLoaders`` to configure itself, so a
    throwaway one is assembled from the single file ``test_fish.jpg``.
    NOTE(review): the identity ``label_func`` makes each photo its own label,
    which looks like a placeholder used only to satisfy the constructor; confirm.

    Returns:
        ``(data_loader, segmenter)``: the placeholder DataLoaders and the U-Net
        learner after ``load('fish_mask_model')``.
    """
    def same_path(fname):
        # Identity labeller: the "label" for each file is the file itself.
        return fname

    mask_codes = np.array(["Photo", "Masks"], dtype=str)
    squish_resize = fai_vision.Resize(256, method='squish')
    to_float = fai_vision.IntToFloatTensor(div_mask=255)

    data_loader = fai_vision.SegmentationDataLoaders.from_label_func(
        path=Path("."),
        fnames=[Path('test_fish.jpg')],
        label_func=same_path,
        codes=mask_codes,
        item_tfms=[squish_resize],
        batch_tfms=[to_float],
        bs=1,
        valid_pct=0.2,
        num_workers=0,
    )
    segmenter = fai_vision.unet_learner(data_loader, fai_vision.resnet34)
    segmenter.load('fish_mask_model')
    return data_loader, segmenter

@st.cache(allow_output_mutation=True)
def load_classification_model():
    """Load the exported fastai classifier from ``models/fish_classification_model.pkl``.

    Cached by Streamlit so the pickle is read once. On POSIX systems,
    ``pathlib.WindowsPath`` is aliased to ``PosixPath`` while the pickle loads.
    NOTE(review): this presumably works around a learner exported on Windows,
    whose pickled ``WindowsPath`` objects cannot be instantiated on POSIX. The
    alias is undone afterwards (even if loading fails) so the rest of the
    process keeps the real ``WindowsPath``.

    Returns:
        The learner returned by ``load_learner(..., cpu=True)``.
    """
    model_path = 'models/fish_classification_model.pkl'

    if os.name != 'posix':
        return fai_vision.load_learner(model_path, cpu=True)

    real_windows_path = pathlib.WindowsPath
    # Pickle resolves classes by module attribute at load time, so the alias
    # only needs to be in place while load_learner runs.
    pathlib.WindowsPath = pathlib.PosixPath
    try:
        return fai_vision.load_learner(model_path, cpu=True)
    finally:
        pathlib.WindowsPath = real_windows_path

# File extensions accepted by the uploader in main().
IMAGE_TYPES = [
    "png",
    "jpg",
    "jpeg",
]

if __name__ == "__main__":
    # Run the app when this file is executed as a script.
    main()