|
import streamlit as st |
|
import os |
|
import urllib |
|
import fastai.vision.all as fai_vision |
|
import numpy as np |
|
from pathlib import Path |
|
import pathlib |
|
from PIL import Image |
|
import platform |
|
import altair as alt |
|
import pandas as pd |
|
import frontmatter |
|
|
|
def main():
    """Render the fish masking and classification page.

    Shows the body of README.md, loads both (cached) models, then waits for
    an upload. For an uploaded photo it displays the original image, the
    masked image together with the percentage of pixels labelled as fish,
    and a bar chart of the classifier's most probable classes.
    """
    # Local import: only needed for the per-request scratch directory below.
    import tempfile

    st.title('Fish Masker and Classifier')

    # frontmatter.load splits off any metadata header; only the Markdown body
    # is rendered. An explicit encoding avoids locale-dependent decoding.
    with open('README.md', encoding='utf-8') as readme_file:
        readme = frontmatter.load(readme_file)
    st.markdown(readme.content)

    # Only the learner is used here; the returned DataLoaders is ignored.
    _, segmenter = load_unet_model()
    classification_model = load_classification_model()

    st.markdown("## Instructions")
    st.markdown("Upload an Amazonian fish photo for masking.")
    uploaded_image = st.file_uploader("", IMAGE_TYPES)
    if not uploaded_image:
        return

    image_data = uploaded_image.read()
    st.markdown('## Original image')
    st.image(image_data, use_column_width=True)

    # Image.open seeks file objects back to the start, so the already-read
    # upload can be reopened. Convert to RGB because JPEG cannot store alpha
    # or palette modes: saving an RGBA/LA/P-mode PNG as .jpg raises OSError.
    original_pil = Image.open(uploaded_image).convert('RGB')

    # Streamlit serves sessions concurrently; fixed file names in the working
    # directory could be overwritten by another user's upload mid-run, so use
    # a private scratch directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as work_dir:
        original_path = Path(work_dir) / 'original.jpg'
        masked_path = Path(work_dir) / 'masked.jpg'

        original_pil.save(original_path)

        # Segment the JPEG as written to disk, and mask that same decoded
        # file, so the mask lines up with exactly the pixels the model saw.
        input_dl = segmenter.dls.test_dl([original_path])
        masks, _ = segmenter.get_preds(dl=input_dl)
        with Image.open(original_path) as saved_pil:
            masked_pil, percentage_fish = mask_fish_pil(saved_pil, masks[0])

        st.markdown('## Masked image')
        st.markdown(f'**{percentage_fish:.1f}%** of pixels were labeled as "fish"')
        st.image(masked_pil, use_column_width=True)

        # The classifier is given the path of the saved (JPEG-encoded) masked image.
        masked_pil.save(masked_path)

        st.markdown('## Classification')
        prediction = classification_model.predict(str(masked_path))

    pred_chart = predictions_to_chart(prediction, classes=classification_model.dls.vocab)
    st.altair_chart(pred_chart, use_container_width=True)
|
|
|
|
|
def mask_fish_pil(unmasked_fish, fastai_mask):
    """Black out every pixel the segmenter did not label as fish.

    Args:
        unmasked_fish: PIL image to mask (the photo that was segmented).
        fastai_mask: per-class prediction tensor for one image (one element of
            ``Learner.get_preds`` output). Classes are reduced with
            ``argmax(dim=0)``, so dimension 0 is assumed to be the class axis,
            presumably shape ``(n_classes, H, W)`` -- TODO confirm.

    Returns:
        ``(masked_pil, percentage_fish)``: the masked image at the original
        resolution, and the percentage (0-100) of mask pixels whose predicted
        class is non-zero (class 0 is treated as background).
    """
    unmasked_np = np.array(unmasked_fish)

    class_map = fastai_mask.argmax(dim=0).numpy()
    percentage_fish = np.count_nonzero(class_map) / class_map.size * 100

    # Binary 0/255 mask: any non-zero class counts as fish, matching the
    # percentage above. Rescaling by the map's max/min instead divides by zero
    # when no fish is found and blanks the whole image when everything is fish.
    mask_u8 = np.where(class_map != 0, 255, 0).astype(np.uint8)

    # Upsample from model resolution to the photo's (width, height); bilinear
    # resampling softens the mask edges.
    mask_u8 = np.array(
        Image.fromarray(mask_u8).resize(unmasked_np.shape[1::-1], Image.BILINEAR)
    )
    weights = mask_u8 / 255
    if unmasked_np.ndim == 3:
        # Broadcast the single-channel mask across the colour channels; a 2-D
        # (grayscale) image is multiplied element-wise as-is.
        weights = weights[..., np.newaxis]

    masked_fish_np = (unmasked_np * weights).astype(np.uint8)
    return Image.fromarray(masked_fish_np), percentage_fish
|
|
|
def predictions_to_chart(prediction, classes, top_k=4):
    """Build an Altair bar chart of the most probable classes.

    Args:
        prediction: result of the classifier's ``predict``; element ``[2]``
            is read as one probability per class, in the order of ``classes``.
        classes: class labels (e.g. ``dls.vocab``), indexed like the
            probabilities.
        top_k: number of highest-probability classes to plot (default 4).

    Returns:
        ``alt.Chart`` with the probability in percent (0-100, rounded to two
        decimals) on x and the class name on y, ordered most to least probable.
    """
    pred_rows = [
        {'class': classes[i], 'probability': round(float(conf) * 100, 2)}
        for i, conf in enumerate(prediction[2])
    ]
    # Explicit columns keep sort_values working even for an empty prediction.
    pred_df = pd.DataFrame(pred_rows, columns=['class', 'probability'])

    top_probs = pred_df.sort_values('probability', ascending=False).head(top_k)

    chart = (
        alt.Chart(top_probs)
        .mark_bar()
        .encode(
            x=alt.X("probability:Q", scale=alt.Scale(domain=(0, 100))),
            y=alt.Y(
                "class:N",
                sort=alt.EncodingSortField(field="probability", order="descending"),
            ),
        )
    )
    return chart
|
|
|
@st.cache(allow_output_mutation=True)
def load_unet_model():
    """Construct the U-Net segmenter and load its trained weights.

    fastai builds a learner from a DataLoaders object, so one is created from
    the single local image ``test_fish.jpg``. Its label function returns each
    item unchanged, so the image path doubles as its own "mask". This appears
    to exist only to give ``unet_learner`` its input pipeline (256px squish
    resize, masks divided by 255).

    Returns:
        ``(data_loader, segmenter)``: the throwaway DataLoaders and the U-Net
        learner (resnet34 encoder) with the ``fish_mask_model`` weights loaded.
    """
    mask_codes = np.array(["Photo", "Masks"], dtype=str)
    resize_tfms = [fai_vision.Resize(256, method='squish')]
    float_tfms = [fai_vision.IntToFloatTensor(div_mask=255)]

    dls = fai_vision.SegmentationDataLoaders.from_label_func(
        path=Path("."),
        bs=1,
        fnames=[Path('test_fish.jpg')],
        label_func=lambda fname: fname,
        codes=mask_codes,
        item_tfms=resize_tfms,
        batch_tfms=float_tfms,
        valid_pct=0.2,
        num_workers=0,
    )

    learner = fai_vision.unet_learner(dls, fai_vision.resnet34)
    learner.load('fish_mask_model')
    return dls, learner
|
|
|
@st.cache(allow_output_mutation=True)
def load_classification_model():
    """Load the exported fastai fish classifier for CPU inference.

    Returns:
        The learner unpickled from
        ``models/fish_classification_model.pkl`` with ``cpu=True``.

    NOTE(review): ``load_learner`` unpickles the file, so it must only ever
    point at a trusted, locally shipped model.
    """
    model_path = 'models/fish_classification_model.pkl'

    if platform.system() == 'Windows':
        return fai_vision.load_learner(model_path, cpu=True)

    # The pickle presumably references pathlib.WindowsPath (exported on
    # Windows), which cannot be instantiated on other systems. Alias it to
    # PosixPath only while unpickling, then restore it, so the stdlib module
    # is not left patched for the rest of the process.
    original_windows_path = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
    try:
        return fai_vision.load_learner(model_path, cpu=True)
    finally:
        pathlib.WindowsPath = original_windows_path
|
|
|
# File extensions accepted by the uploader in main().
IMAGE_TYPES = "png jpg jpeg".split()


if __name__ == "__main__":
    # Run the app when this module is executed as a script.
    main()