demo_unet / app.py
h2chen's picture
Update app.py
9e60d68
raw
history blame contribute delete
No virus
1.48 kB
from fastai.vision.all import *
from io import BytesIO
import requests
import streamlit as st
"""
# U-Net
This is a segmentation model for images of Brain MRI.
"""
def predict(img):
    """Display the input image and the segmentation mask predicted for it.

    Args:
        img: The image to segment. It is passed unchanged to ``st.image``
            and to ``learn_inf.predict``.

    Returns:
        None. All output is rendered into the Streamlit page.
    """
    st.image(img, caption="Your image", use_column_width=True)

    # The first item of the predict() result is used as the mask.
    pred_mask = learn_inf.predict(img)[0].numpy()

    # Stretch the mask values towards the 0-255 grayscale range so that the
    # predicted regions are visible. An all-zero mask (nothing segmented) has
    # max() == 0; dividing by it would yield inf and int(inf) raises
    # OverflowError, so such a mask is left unscaled.
    peak = pred_mask.max()
    gray_ratio = int(255 / peak) if peak > 0 else 1
    pred_mask = pred_mask * gray_ratio

    # Explicit call instead of relying on a bare string being picked up by
    # Streamlit magic in the middle of a function body.
    st.markdown("\n### Prediction result:\n")
    st.image(pred_mask, caption="Prediction Mask", use_column_width=True)
def acc_camvid(inp, targ):
    """Return pixel accuracy over all pixels whose target is not the void class.

    Args:
        inp: Per-class scores with the class axis at dim 1 (argmax is taken
            over that axis).
        targ: Target labels with a singleton dim 1, which is squeezed away
            before comparison.

    Returns:
        A scalar float tensor: the fraction of non-void pixels where the
        arg-max class equals the target label.
    """
    # NOTE(review): presumably kept in this file so the exported learner can
    # resolve the metric by name when it is loaded -- confirm before removing.
    void_label = 0
    labels = targ.squeeze(1)
    keep = labels != void_label
    predicted = inp.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()
def label_func(x):
    """Return the mask path for image path ``x``.

    The mask sits in the same directory as ``x`` and is named
    ``<stem>_mask.png``, whatever the suffix of ``x`` is.
    """
    mask_name = f"{x.stem}_mask.png"
    directory = x.parents[0]
    return directory / mask_name
# Load the exported learner from the app directory.
path = "./"
learn_inf = load_learner(path + "model-34-stage2")

# Let the user choose how to supply the image, then run the prediction.
option = st.radio("", ["Upload Image", "Image URL"])
if option == "Upload Image":
    uploaded_file = st.file_uploader("Please upload an image.")
    if uploaded_file is not None:
        img = PILImage.create(uploaded_file)
        predict(img)
else:
    url = st.text_input("Please input a url.")
    if url:
        # Only the download and decode steps are guarded, so a failure inside
        # predict() is not misreported as a problem with the URL.
        try:
            # The URL is user-supplied: bound the wait and reject HTTP errors
            # instead of trying to decode an error page as an image.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            pil_img = PILImage.create(BytesIO(response.content))
        except (requests.RequestException, OSError):
            # OSError covers PIL failing to identify or decode the payload.
            # st.text takes a single body argument, so the URL is formatted
            # into the message rather than passed as a second positional.
            st.text(f"Problem reading image from {url}")
        else:
            predict(pil_img)