|
from fastai.vision.all import * |
|
import gradio as gr |
|
|
|
# Class index -> raw grayscale value stored in the mask files.
# `get_msk` replaces every pixel equal to the value with its key.
# NOTE(review): index i presumably corresponds to codes[i] — confirm.
p2c = {
    0: 255,
    1: 76,
    2: 150,
    3: 119,
}
|
|
|
def get_msk(fn, p2c):
    """Load the mask paired with image `fn` and recode its pixels as class indices.

    The mask is read from ``path/'Mask1'/<stem>_P<suffix>``. ``path`` is a
    module-level name that is not defined in this part of the file; it must
    exist before this function is called.

    Args:
        fn: Path-like object of the image; only ``stem`` and ``suffix`` are used.
        p2c: Dict mapping class index -> raw pixel value found in the mask file.

    Returns:
        A ``PILMask`` in which every pixel equal to ``p2c[k]`` has been
        replaced by ``k``. Pixels matching no entry keep their raw value.
    """
    mask_path = path/'Mask1'/f'{fn.stem}_P{fn.suffix}'
    raw = np.array(PILMask.create(mask_path))

    # Match against the untouched `raw` array and write into a copy, so a
    # pixel already recoded to class k can never be re-matched by a later
    # entry whose raw pixel value happens to equal k.
    recoded = raw.copy()
    for class_idx, pixel_value in p2c.items():
        recoded[raw == pixel_value] = class_idx

    return PILMask.create(recoded)
|
|
|
# Segmentation class names; the position in this array is the class index.
codes = np.array([
    'Cell',
    'Cell Border',
    'Mitochondria',
    'Background',
])
|
|
|
def get_y(o):
    """Return the class-index mask for image file `o`, using the module-level `p2c` table."""
    recoded_mask = get_msk(fn=o, p2c=p2c)
    return recoded_mask
|
|
|
# Reverse lookup: class name -> class index (position in `codes`).
name2id = dict(zip(codes, range(len(codes))))

# Index of the class that `acc_cells` leaves out of the accuracy.
void_code = name2id['Background']
|
def acc_cells(inp, targ):
    """Pixel accuracy over every target pixel that is not `void_code`.

    Args:
        inp: Model output; the predicted class is taken as the argmax over dim 1.
        targ: Target class indices with a singleton dim 1, which is squeezed away.

    Returns:
        Scalar tensor: the fraction of non-void pixels predicted correctly.
    """
    labels = targ.squeeze(1)
    keep = labels != void_code
    predicted = inp.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()
|
|
|
def segment_image(img):
    """Blend a grayscale rendering of the predicted mask over `img`.

    Args:
        img: Input picture as a NumPy array or a PIL image. ``None`` (no
            image supplied) is passed straight through.

    Returns:
        A PIL image mixing the RGB input with the mask (mask weight 0.7),
        or ``None`` when `img` is ``None``.
    """
    if img is None:
        return None

    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    img = PILImage.create(img.convert("RGB"))

    # The first element of the prediction is used as the class-index map
    # (presumably fastai's decoded mask — confirm against Learner.predict).
    class_map = learn.predict(img)[0].numpy()

    # Spread the class indices evenly over 0..255. Multiplying by 255 and
    # casting to uint8 wraps around for every index above 1 (2 -> 254,
    # 3 -> 253), which makes those classes almost indistinguishable.
    step = 255 // max(len(codes) - 1, 1)
    gray = (class_map * step).astype(np.uint8)

    # NEAREST keeps the class boundaries hard when matching the input size;
    # after this resize the sizes are equal, so no second check is needed.
    mask_img = Image.fromarray(gray).convert("RGB").resize(img.size, Image.NEAREST)

    return Image.blend(img, mask_img, alpha=0.7)
|
|
|
|
|
def _class_palette(n_classes):
    """Return one RGB colour (uint8 triple) per class, sampled from 'tab20'."""
    cmap = matplotlib.colormaps.get_cmap('tab20')
    # cmap(x) yields RGBA floats in 0..1; keep RGB and scale to 0..255.
    return [
        (np.array(cmap(idx / n_classes))[:3] * 255).astype(np.uint8)
        for idx in range(n_classes)
    ]


def segment_image2(img):
    """Blend a colour-coded rendering of the predicted mask over `img`.

    Args:
        img: Input picture in any form accepted by ``PILImage.create``.
            ``None`` (no image supplied) is passed straight through.

    Returns:
        A PIL image mixing the input with the coloured mask (mask weight
        0.7), or ``None`` when `img` is ``None``.
    """
    if img is None:
        return None

    img = PILImage.create(img)

    # The first element of the prediction is used as the class-index map
    # (presumably fastai's decoded mask — confirm against Learner.predict).
    class_map = learn.predict(img)[0].numpy()

    # Paint each class with its palette colour; pixels whose index is not
    # below len(codes) stay black.
    rgb_mask = np.zeros((*class_map.shape, 3), dtype=np.uint8)
    for class_idx, color in enumerate(_class_palette(len(codes))):
        rgb_mask[class_map == class_idx] = color

    # NEAREST avoids inventing in-between colours at class boundaries; after
    # this resize the sizes are equal, so no second check is needed.
    mask_img = Image.fromarray(rgb_mask).resize(img.size, Image.NEAREST)

    return Image.blend(img, mask_img, alpha=0.7)
|
|
|
# Load the exported learner used by the segment_image* functions.
# NOTE(review): load_learner presumably unpickles this file, so 'export.pkl'
# must come from a trusted source, and the helper functions defined in this
# file may need to stay importable under their current names — confirm.
learn = load_learner('export.pkl')

# Gradio components: uploaded picture in, blended overlay out.
image = gr.components.Image()
mask_img = gr.components.Image()

# Sample images offered in the UI (the duplicated 'Image 1.png' entry was removed).
examples = [
    'Cell-142.png',
    'Image 1.png',
    'Image 2.png',
    'Image 3.png',
    'Image 4.png',
    'Image 5.png',
    'Image 6.png',
    'Image 7.png',
    'Image 8.png',
    'Image 9.png',
    'Image 10.png',
]

demo = gr.Interface(fn=segment_image2, inputs=image, outputs=mask_img, examples=examples)

demo.launch()
|
|