fcossio's picture
Update model
a15d397
raw
history blame
3.27 kB
import gradio as gr
import tensorflow as tf
import cv2
import numpy as np
from itertools import islice
from PIL import Image, ImageDraw, ImageColor
from line_fit import LineFit
import random
import lwbna_unet as u
# Maximum number of sketched blobs measured per run; it is also the batch
# size of the array that predict_mask feeds to the model.
MAX_SELECTIONS = 8
# Shape used to build the network: (batch, rows, cols, channels). In
# predict_mask, channel 0 holds the image and channel 1 marks the selected
# point.
input_shape = (MAX_SELECTIONS,) + (256, 256, 2)
def load_model(weights_path='royal-snowflake-2400.hdf5'):
    """Build the LWBNA U-Net and load trained weights into it.

    Args:
        weights_path: HDF5 weights file passed to ``load_weights``. Defaults
            to the checkpoint file name this app was written for, so existing
            ``load_model()`` calls behave as before.

    Returns:
        The built ``lwbna_unet.LWBNAUnet`` instance with weights loaded.
    """
    # Positional constructor arguments are unchanged from the original setup;
    # their meaning is defined in lwbna_unet (not visible here).
    unet = u.LWBNAUnet(1, 128, 8, 4)
    # Build with the module-level (batch, rows, cols, channels) shape so the
    # variables exist before load_weights (presumably required because
    # LWBNAUnet is a subclassed Keras model — confirm).
    unet.build(input_shape)
    unet.load_weights(weights_path)
    return unet
# Network instance created once at import time; predict_mask reuses it for
# every request.
model = load_model()
# All color names PIL knows, in colormap order. In this file it is only
# referenced by the commented-out overlay code in predict_mask.
colors = [*ImageColor.colormap]
# Line-fitting helper that predict_mask applies to each thresholded prediction.
linefit = LineFit(30, 0.3)
def get_blob_centroids(mask):
    """Return the integer centroid ``[cx, cy]`` of every blob in *mask*.

    Args:
        mask: single-channel uint8 array. OpenCV treats nonzero pixels as
            foreground.

    Returns:
        A list of ``[cx, cy]`` pairs (column, row). The coordinates come from
        the contour moments and are truncated with ``int``. Pairs are in the
        order OpenCV reports the contours. Contours whose zeroth moment
        (area) is 0 are skipped to avoid dividing by zero.
    """
    # RETR_EXTERNAL keeps one outer contour per blob. With RETR_LIST, a blob
    # sketched as a closed ring also reported its hole, which produced a
    # duplicate centroid for the same blob.
    # [-2] selects the contour list under both OpenCV 3.x
    # (image, contours, hierarchy) and 4.x (contours, hierarchy) signatures.
    contours = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    centers = []
    for contour in contours:
        moments = cv2.moments(contour)
        area = moments['m00']
        if area != 0:
            centers.append([int(moments['m10'] / area),
                            int(moments['m01'] / area)])
    return centers
def predict_mask(input_img, threshold):
    """Measure the fibers selected by sketched blobs and draw the fitted lines.

    The centroid of each blob sketched on the mask becomes one model input.
    At most MAX_SELECTIONS centroids are used, in the order
    ``get_blob_centroids`` returns them. Each input has:

    - channel 0: the normalized image;
    - channel 1: 1.0 only at the centroid pixel, 0 elsewhere.

    Each prediction is thresholded and handed to ``linefit.predict``. Its
    measurement is collected and its lines are drawn in blue over the image.

    Args:
        input_img: dict from the gradio sketch input with an ``"image"`` array
            and a ``"mask"`` array whose channel 0 holds the sketch (assumed
            uint8 in 0..255 with the same rows/cols as the image — TODO
            confirm against the gradio version in use).
        threshold: cutoff applied to the model output (``p > threshold``).

    Returns:
        Tuple ``(annotated RGBA PIL image, measurements)``. ``measurements``
        holds one value from ``linefit.predict`` per selected blob and is
        empty when nothing was sketched.
    """
    raw_im, mask = input_img["image"], input_img["mask"]
    mask = mask[:, :, 0].astype(np.uint8)
    # Model-input normalization left as /256; the weights were presumably
    # trained with it — confirm before changing.
    im = raw_im.astype(np.float32) / 256

    # Draw on the original pixels. Rebuilding the background as
    # (raw / 256) * 255 went through a float "F" image and slightly darkened it.
    canvas = Image.fromarray(raw_im.astype(np.uint8)).convert("RGBA")
    draw = ImageDraw.Draw(canvas)

    # Only the first MAX_SELECTIONS blobs fit into the model batch.
    centers = list(islice(get_blob_centroids(mask), MAX_SELECTIONS))
    if not centers:
        # Nothing selected: the model output would be ignored anyway.
        return canvas, []

    batch = np.zeros([MAX_SELECTIONS, 256, 256, 2], dtype=np.float32)
    for i, (cx, cy) in enumerate(centers):
        batch[i, :, :, 0] = im
        # Point marker at the centroid (row = cy, column = cx).
        batch[i, cy, cx, 1] = 1.0
    pred = model.predict(batch, verbose=0).squeeze()

    ds = []
    # Slots beyond len(centers) were fed all-zero inputs; ignore them.
    for p in pred[:len(centers)]:
        binary = (p > threshold).astype(np.uint8) * 255
        d, lines = linefit.predict(binary)
        ds.append(d)
        for line in lines:
            draw.line(line, fill="blue", width=1)
    return canvas, ds
# Gradio UI: the user sketches blobs on the input image and presses "Run".
# predict_mask then returns the annotated image plus the measurement list.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation was lost — confirm the intended layout.
demo = gr.Blocks()
with demo:
    with gr.Column():
        gr.Markdown("""
# Measure the diameter of fibers
Select fibers that you want to measure. Only the center of each sketched blob is considered.
The recommended pixel size for uploaded images is 256x256.
""")
        with gr.Row():
            # Left column: sketchable input, threshold slider and run button.
            with gr.Column():
                # Its value reaches predict_mask as a dict with "image" and
                # "mask" entries. image_mode='L' is PIL's 8-bit grayscale mode.
                # NOTE(review): tool=/source=/shape= are gradio 3.x Image
                # arguments (removed in gradio 4) — keep gradio pinned to 3.x.
                img = gr.Image(
                    tool="sketch",
                    source="upload",
                    label="Mask",
                    image_mode='L',
                    shape=[256,256],
                    value='test0000.png'
                )
                # Passed to predict_mask as the cutoff on the model output.
                threshold = gr.Slider(
                    label='Segmentation threshold', minimum=0, maximum=1, value=0.5)
                with gr.Row():
                    btn = gr.Button("Run")
            # Right column: annotated result image and the measurements.
            with gr.Column():
                img2 = gr.Image()
                text = gr.Text(label='Measurements in px')
        # predict_mask's two return values fill img2 and text, in that order.
        btn.click(fn=predict_mask, inputs=[img, threshold], outputs=[img2,text], )
        # Example files that can be loaded into the sketch input.
        examples = gr.Examples(examples=['test0001.png',
                                         'test0002.png',
                                         'test0003.png'],
                               inputs=img)
# Launches the app as soon as the module is executed (no __main__ guard).
demo.launch()