File size: 3,742 Bytes
e6d6732
 
 
 
 
 
 
 
c9a01bf
e6d6732
 
 
 
a15d397
c9a01bf
a15d397
c9a01bf
e6d6732
 
 
 
 
 
 
b215ba8
 
e6d6732
 
 
 
 
 
 
 
b215ba8
e6d6732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b215ba8
e6d6732
 
 
 
 
 
 
 
b215ba8
 
068a4c2
6beabe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6d6732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb455f8
e6d6732
 
 
 
6beabe4
8c5661c
e6d6732
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import tensorflow as tf
import cv2
import numpy as np
from itertools import islice
from PIL import Image, ImageDraw, ImageColor
from line_fit import LineFit
import random
import lwbna_unet as u

# Maximum number of sketched selections processed per request; it is also the
# fixed batch dimension of the array fed to the network in predict_mask.
MAX_SELECTIONS = 8
# Shape the network is built with: MAX_SELECTIONS samples of 256x256 pixels
# with 2 channels (channel 0 holds the image, channel 1 the selection marker).
input_shape = (MAX_SELECTIONS,256,256,2)
def load_model():
    """Build the LWBNAUnet network and load its pretrained weights.

    The network is created as ``u.LWBNAUnet(1, 128, 8, 4)``, built for the
    module-level ``input_shape`` and then filled with the weights stored in
    ``royal-snowflake-2400.hdf5`` (resolved relative to the working
    directory).

    Returns:
        The model instance with the checkpoint weights loaded.
    """
    # NOTE(review): the meaning of the positional constructor arguments is
    # defined in lwbna_unet and is not visible here -- confirm before changing.
    net = u.LWBNAUnet(1, 128, 8, 4)
    net.build(input_shape)
    net.load_weights('royal-snowflake-2400.hdf5')
    return net

# The network is loaded once at import time and shared by every request.
model = load_model()
# Names of all colours known to PIL; not used by any active code in the
# visible part of this file.
colors = list(ImageColor.colormap)
# Line fitter used to measure each segmented fiber.
# NOTE(review): the meaning of (30, 0.3) is defined in line_fit -- confirm there.
linefit = LineFit(30, 0.3)

def get_blob_centroids(mask):
    """Return the integer centroid of every blob found in ``mask``.

    Args:
        mask: Image accepted by ``cv2.findContours`` (the caller in this file
            passes a 2-D ``uint8`` array); non-zero pixels are foreground.

    Returns:
        A list of ``[cx, cy]`` pairs (x = column, y = row), one per contour,
        in the order OpenCV reports the contours. Contours whose area moment
        ``m00`` is zero are skipped because their centroid is undefined.
    """
    # findContours returns (contours, hierarchy) on OpenCV 2.x/4.x but
    # (image, contours, hierarchy) on 3.x; index -2 is the contour list in
    # every version, and the hierarchy is not needed here.
    contours = cv2.findContours(
        mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

    centers = []
    for contour in contours:
        moments = cv2.moments(contour)
        area = moments['m00']
        if area == 0:
            # Degenerate contour (no enclosed area): avoid dividing by zero.
            continue
        # int() truncates so the coordinates can be used as pixel indices.
        cx = int(moments['m10'] / area)
        cy = int(moments['m01'] / area)
        centers.append([cx, cy])
    return centers
    
def predict_mask(input_img, threshold):
    """Segment the fibers selected in the sketch and measure each of them.

    Args:
        input_img: Dict with the keys ``"image"`` and ``"mask"``. ``"image"``
            must be assignable to a 256x256 slice of the model batch;
            ``"mask"`` is indexed as a 3-D array and only its first channel
            is used to locate the sketched blobs.
        threshold: Cut-off applied to the model output; pixels with a
            prediction strictly greater than it count as fiber.

    Returns:
        A tuple ``(image, measurements)``: the input rendered as an RGBA PIL
        image with the fitted lines drawn in blue, and a list holding the
        first value returned by ``linefit.predict`` for each processed
        selection. At most ``MAX_SELECTIONS`` selections are processed.
    """
    # Unpack and rescale the input.
    im, mask = input_img["image"], input_img["mask"]
    mask = mask[:, :, 0].astype(np.uint8)
    im = im.astype(np.float32) / 256

    # One centroid per sketched blob; anything beyond the batch size is dropped.
    centers = get_blob_centroids(mask)[:MAX_SELECTIONS]

    # Canvas for the overlay: the (rescaled) input converted to RGBA.
    overlay = Image.fromarray(im * 255).convert("RGBA")
    measurements = []
    if not centers:
        # Nothing selected: skip inference on an all-zero batch entirely.
        return overlay, measurements

    # The batch always has MAX_SELECTIONS entries; unused entries stay zero.
    batch = np.zeros([MAX_SELECTIONS, 256, 256, 2], dtype=np.float32)
    for i, (cx, cy) in enumerate(centers):
        batch[i, :, :, 0] = im        # channel 0: the image itself
        batch[i, cy, cx, 1] = 1.0     # channel 1: single marker at the click

    pred = model.predict(batch, verbose=0).squeeze()

    draw = ImageDraw.Draw(overlay)
    # Only the leading predictions correspond to real selections.
    for p in pred[:len(centers)]:
        binary = (p > threshold).astype(np.uint8) * 255
        d, lines = linefit.predict(binary)
        measurements.append(d)
        for line in lines:
            draw.line(line, fill="blue", width=1)

    return overlay, measurements

# Build the Gradio interface: an explanatory header, the sketchable input
# image with a threshold slider on the left, and the results on the right.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("""
# Measure the diameter of fibers

## How to use this app

Select fibers that you want to measure. For each fiber click once to draw a blob. Only the center of each sketched blob is considered for the selection of the fiber. You will obtain the measurements for the selected fibers. 

## Input

An image and a selection point. The recommended pixel size for uploaded images is 256x256.

## Output
The measurements for the fiber that was selected.

## Examples

The first three examples were part of the rendered hold out set, the fourth is a real micrograph also never seen by the model during training.


""")
    with gr.Row():
        with gr.Column():
            # Input side. predict_mask reads the value of this component as a
            # dict with the keys "image" and "mask".
            img = gr.Image(
                value='test0000.png',
                label="Mask",
                source="upload",
                tool="sketch",
                image_mode='L',
                shape=[256, 256],
            )
            threshold = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                label='Segmentation threshold',
            )

            with gr.Row():
                btn = gr.Button("Run")
        with gr.Column():
            # Output side: annotated image and the measurements as text.
            img2 = gr.Image()
            text = gr.Text(label='Measurements in px')

    # Run the segmentation and measurement when the button is pressed.
    btn.click(fn=predict_mask, inputs=[img, threshold], outputs=[img2, text])
    examples = gr.Examples(
        examples=[
            'test0001.png',
            'test0002.png',
            'test0003.png',
            'real_test0000.jpg',
        ],
        inputs=img,
    )

demo.launch()