SegmHTI / app.py
itahir's picture
added new test images
f415e23
from fastai.vision.all import *
import gradio as gr
# Class index -> grey level used for that class in the raw mask files.
# `get_msk` rewrites every pixel equal to a value here to the matching key.
# NOTE(review): keys are presumably aligned with the order of `codes` -- confirm.
p2c = {
    0: 255,
    1: 76,
    2: 150,
    3: 119,
}
def get_msk(fn, p2c):
    """Grab a mask from a `filename` and adjust the pixels based on `pix2class`.

    `fn` is the path of an input image; the mask is read from
    `path/'Mask1'/<stem>_P<suffix>`. `p2c` maps class index -> pixel value
    found in the mask file. Returns a `PILMask` in which each listed pixel
    value has been replaced by its class index; pixels whose value is not in
    `p2c` are left unchanged.
    """
    # NOTE(review): `path` is not defined in the visible part of this file --
    # confirm it exists before this function is actually called.
    fn = path/'Mask1'/f'{fn.stem}_P{fn.suffix}'
    raw = np.array(PILMask.create(fn))
    # Compare against the untouched copy so a class index written by one
    # iteration can never be mistaken for a pixel value by a later one.
    msk = raw.copy()
    for class_idx, pixel_val in p2c.items():
        msk[raw == pixel_val] = class_idx
    return PILMask.create(msk)
# Segmentation class names; a name's position in the array is its class index
# (see `name2id`, which is built by enumerating this array).
codes = np.array([
    'Cell',
    'Cell Border',
    'Mitochondria',
    'Background',
])
def get_y(o):
    "Return the class-index mask for image file `o`, using the module-level `p2c` table"
    # NOTE(review): nothing in this file calls `get_y`; presumably it must stay
    # defined so the exported learner can resolve it when loaded -- confirm.
    mask = get_msk(o, p2c=p2c)
    return mask
# Reverse lookup: class name -> class index (its position in `codes`).
name2id = dict(zip(codes, range(len(codes))))
# Index of the 'Background' class; `acc_cells` excludes pixels carrying this label.
void_code = name2id['Background']
def acc_cells(inp, targ):
    """Accuracy of `inp` against `targ`, ignoring pixels labelled `void_code`.

    The predicted class is the argmax of `inp` over dim 1. `targ` has its
    dim 1 squeezed before comparison.
    NOTE(review): presumably `inp` is (batch, classes, H, W) and `targ` is
    (batch, 1, H, W) -- confirm against the training code.
    """
    labels = targ.squeeze(1)
    # Only pixels whose ground-truth label is not the void class are scored.
    keep = labels != void_code
    predicted = inp.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()
def segment_image(img):
    """Run the model on `img` and return the image with a grey-level mask overlaid.

    `img` may be a NumPy array or a PIL image. The prediction is rendered as a
    greyscale picture (one grey level per class), resized to the input's size
    with nearest-neighbour sampling, and blended onto the RGB input with
    alpha=0.7. Returns the blended PIL image.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Image.blend needs both images in the same mode, so force RGB up front.
    img = img.convert("RGB")
    img = PILImage.create(img)
    results = learn.predict(img)

    # NOTE(review): results[0] is presumably the per-pixel class-index mask -- confirm.
    mask = results[0].numpy()
    # Spread the class indices evenly over 0..255. Multiplying by 255 directly
    # wraps around in uint8 for every index >= 2 (2 -> 254, 3 -> 253), which
    # makes those classes indistinguishable from class 1.
    grey_step = 255 // max(len(codes) - 1, 1)
    mask = (mask * grey_step).astype(np.uint8)

    # Nearest-neighbour keeps the grey levels exact (no interpolated values).
    # Resizing to img.size here already guarantees matching sizes for blend.
    mask_img = Image.fromarray(mask).convert("RGB").resize(img.size, Image.NEAREST)
    return Image.blend(img, mask_img, alpha=0.7)
def segment_image2(img):
    """Run the model on `img` and return the image with a colour-coded mask overlaid.

    `img` is anything `PILImage.create` accepts. Each class in `codes` gets a
    colour sampled from matplotlib's 'tab20' colormap; the coloured mask is
    resized to the input's size with nearest-neighbour sampling and blended
    onto the input with alpha=0.7. Returns the blended PIL image.
    """
    img = PILImage.create(img)
    # The colour mask is RGB and Image.blend rejects mismatched modes, so make
    # sure the input is RGB as well (same guarantee `segment_image` gives).
    if img.mode != "RGB":
        img = PILImage.create(img.convert("RGB"))
    results = learn.predict(img)
    # NOTE(review): results[0] is presumably the per-pixel class-index mask -- confirm.
    mask = results[0].numpy()

    # One RGB colour per class, sampled at i / num_classes along the colormap.
    num_classes = len(codes)
    cmap = matplotlib.colormaps.get_cmap('tab20')
    palette = [
        tuple((np.array(cmap(class_idx / num_classes))[:3] * 255).astype(int))
        for class_idx in range(num_classes)
    ]

    # Paint every pixel with the colour of its class; unmatched pixels stay black.
    rgb_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for class_idx, color in enumerate(palette):
        rgb_mask[mask == class_idx] = color

    # Nearest-neighbour keeps class colours exact. Resizing to img.size here
    # already guarantees matching sizes for blend.
    mask_img = Image.fromarray(rgb_mask).convert("RGB").resize(img.size, Image.NEAREST)
    return Image.blend(img, mask_img, alpha=0.7)
# Load the exported learner used by the segment_* functions.
# NOTE: load_learner unpickles the file, so only load an export.pkl you trust.
learn = load_learner('export.pkl')

# Gradio wiring: one image in, one image (the overlay) out.
image = gr.components.Image()
mask_img = gr.components.Image()
# Example inputs offered in the UI; each file is listed once.
examples = [
    'Cell-142.png',
    'Image 1.png',
    'Image 2.png',
    'Image 3.png',
    'Image 4.png',
    'Image 5.png',
    'Image 6.png',
    'Image 7.png',
    'Image 8.png',
    'Image 9.png',
    'Image 10.png',
]
demo = gr.Interface(fn=segment_image2, inputs=image, outputs=mask_img, examples=examples)
demo.launch()