import gradio as gr
import tensorflow as tf
import cv2
import numpy as np
from itertools import islice
from PIL import Image, ImageDraw, ImageColor
from line_fit import LineFit
import random
import lwbna_unet as u

# Maximum number of fibers measured per run: the model is always fed a
# fixed-size batch holding one selected fiber per batch entry.
MAX_SELECTIONS = 8
# Model input: (batch, height, width, channels). Channel 0 carries the image,
# channel 1 marks the selected point (see predict_mask).
input_shape = (MAX_SELECTIONS, 256, 256, 2)

DESCRIPTION = f"""
# Measure the diameter of fibers

## How to use this app
Select the fibers you want to measure. For each fiber, click once to draw a blob.
Only the center of each sketched blob is used to select the fiber.
You will obtain the measurements for the selected fibers (at most {MAX_SELECTIONS} per run).

## Input
An image and one or more selection points. The recommended size for uploaded images is 256x256 pixels.

## Output
The measurements for the selected fibers.

## Examples
The first three examples were part of the rendered held-out set; the fourth is a real micrograph that the model also never saw during training.
"""


def load_model():
    """Build the LWBNA U-Net and load its trained weights.

    Returns:
        The built network with weights loaded from
        'royal-snowflake-2400.hdf5' (presumably a tf.keras model, given the
        build/load_weights/predict API used in this file).
    """
    # NOTE(review): the meaning of the positional constructor arguments is
    # defined in lwbna_unet; they must match the checkpoint loaded below.
    unet = u.LWBNAUnet(1, 128, 8, 4)
    unet.build(input_shape)
    unet.load_weights('royal-snowflake-2400.hdf5')
    return unet


model = load_model()
# Named PIL colours. Currently unused by the app (an earlier, removed overlay
# drew each fiber mask with random.choice(colors)).
colors = list(ImageColor.colormap.keys())
# NOTE(review): the meaning of LineFit's arguments (30, 0.3) is defined in
# line_fit; document them there.
linefit = LineFit(30, 0.3)


def get_blob_centroids(mask):
    """Return the integer centroid ``[cx, cy]`` of every blob in ``mask``.

    Args:
        mask: 2-D ``uint8`` array; non-zero pixels are treated as foreground.

    Returns:
        A list of ``[cx, cy]`` pairs (column, row), one per contour with a
        non-zero area. Zero-area contours are skipped because their centroid
        is undefined (division by ``m00``).
    """
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] picks the contours in both.
    contours = cv2.findContours(
        mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    centers = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments['m00'] != 0:
            cx = int(moments['m10'] / moments['m00'])
            cy = int(moments['m01'] / moments['m00'])
            centers.append([cx, cy])
    return centers


def predict_mask(input_img, threshold):
    """Segment the fibers selected by sketched blobs and measure them.

    Args:
        input_img: Dict produced by the Gradio sketch tool, with the
            greyscale image under ``"image"`` and the sketched blobs under
            ``"mask"``.
        threshold: Cut-off applied to the model output (the UI slider ranges
            from 0 to 1).

    Returns:
        ``(image, measurements)``: the input image as an RGBA PIL image with
        the lines returned by ``linefit`` drawn in blue, and a list with one
        ``linefit`` measurement per selected fiber (at most MAX_SELECTIONS;
        extra blobs are ignored). Both are empty of overlays/measurements when
        no blob was drawn.
    """
    im, mask = input_img["image"], input_img["mask"]
    # The sketch mask may carry a channel axis; one channel is enough to
    # locate the blobs.
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    mask = mask.astype(np.uint8)
    # NOTE(review): /256 (not /255) presumably matches the training
    # preprocessing; confirm before changing it.
    im = im.astype(np.float32) / 256

    centers = get_blob_centroids(mask)
    # Background for the result, scaled back to roughly 0..255 for display.
    overlay = Image.fromarray(im * 255).convert("RGBA")
    if not centers:
        # Nothing selected: skip the fixed-size model call entirely.
        return overlay, []

    # One batch entry per selected fiber: channel 0 holds the image and
    # channel 1 a single 1.0 at the blob centroid, marking which fiber to
    # segment. Unused entries stay all-zero.
    batch = np.zeros(input_shape, dtype=np.float32)
    for i, (cx, cy) in enumerate(islice(centers, MAX_SELECTIONS)):
        batch[i, :, :, 0] = im
        batch[i, cy, cx, 1] = 1.0

    # NOTE(review): squeeze() presumably drops a trailing channel axis,
    # leaving (MAX_SELECTIONS, 256, 256); confirm against the model output.
    pred = model.predict(batch, verbose=0).squeeze()

    draw = ImageDraw.Draw(overlay)
    measurements = []
    # Only the first len(centers) predictions correspond to real selections.
    for p in islice(pred, len(centers)):
        binary = p > threshold
        d, lines = linefit.predict(binary.astype(np.uint8) * 255)
        measurements.append(d)
        for line in lines:
            draw.line(line, fill="blue", width=1)
    return overlay, measurements


demo = gr.Blocks()

with demo:
    with gr.Column():
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column():
                img = gr.Image(
                    tool="sketch",
                    source="upload",
                    label="Mask",
                    image_mode='L',
                    shape=[256, 256],
                    value='test0000.png',
                )
                threshold = gr.Slider(
                    label='Segmentation threshold',
                    minimum=0,
                    maximum=1,
                    value=0.5,
                )
                with gr.Row():
                    btn = gr.Button("Run")
            with gr.Column():
                img2 = gr.Image()
                text = gr.Text(label='Measurements in px')
        btn.click(
            fn=predict_mask,
            inputs=[img, threshold],
            outputs=[img2, text],
        )
        examples = gr.Examples(
            examples=['test0001.png', 'test0002.png', 'test0003.png',
                      'real_test0000.jpg'],
            inputs=img,
        )

demo.launch()