# Hugging Face Space (status banner captured with this source: "Sleeping")
import os
import shutil
import tempfile

import albumentations as A
import gradio as gr
import monai
import numpy as np
import openslide
import torch
from albumentations.pytorch import ToTensorV2
from monai.networks.nets import UNet
from PIL import Image

from project_utils.preprocessing import expand2square
# Segmentation network: a 2-D MONAI U-Net with residual units that takes a
# 3-channel (RGB) input and produces a single output channel of logits
# (process_image applies the sigmoid). These hyper-parameters must match the
# ones "best_model.pth" was trained with, otherwise load_state_dict() rejects
# the checkpoint.
_UNET_KWARGS = {
    "spatial_dims": 2,
    "in_channels": 3,
    "out_channels": 1,
    "channels": [16, 32, 64, 128, 256, 512],
    "strides": (2, 2, 2, 2, 2),
    "num_res_units": 4,
    "dropout": 0.15,
}
model = UNet(**_UNET_KWARGS)
# Map every stored tensor onto the CPU so the app also runs on GPU-less hosts.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device("cpu"))
)
# Inference only: eval() switches off dropout and other train-time behaviour.
model.eval()
# Preprocessing applied before every forward pass:
# - resize to the 512x512 network input (aspect ratio is NOT preserved);
# - convert the HWC array to a CHW tensor.
# Built once at import time rather than on every request.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])


def process_image(image):
    """Predict a per-pixel foreground probability map for an RGB image.

    Args:
        image: H x W x 3 array with 0-255 values, as delivered by the
            ``gr.Image(type="numpy", image_mode="RGB")`` input (assumed
            uint8 -- TODO confirm for non-Gradio callers). ``None`` (the
            user submitted without uploading an image) is accepted.

    Returns:
        The first output channel of the model after a sigmoid, as a 2-D
        float numpy array of probabilities in [0, 1], computed on the
        512 x 512 resized input. Returns ``None`` when ``image`` is
        ``None``, which leaves the Gradio output empty instead of raising.
    """
    if image is None:
        return None
    # Scale to [0, 1] and use float32 to match the model's parameter dtype.
    image = (image / 255.0).astype(np.float32)
    batch = _INFERENCE_TRANSFORMS(image=image)["image"].unsqueeze(0)  # add batch dim
    with torch.no_grad():
        probabilities = torch.sigmoid(model(batch))
    return probabilities[0, 0].numpy()
# "Image-to-Mask" tab: upload a single RGB image and display the mask that
# process_image predicts for it.
interface_image = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
            height=400,
            width=400,
        )
    ],
    # TODO: add histopathology example tiles here via `examples=[...]`.
)
def _find_mrxs(directory):
    """Return the path of the first ``.mrxs`` file found under *directory*.

    The search is recursive (a zip made from a folder nests its contents one
    level down), case-insensitive, and runs in sorted order so the result is
    deterministic.

    Raises:
        gr.Error: if no ``.mrxs`` file is present, so the UI shows a readable
            message instead of a bare ``IndexError``.
    """
    for root, dirs, files in os.walk(directory):
        dirs.sort()  # make os.walk's traversal order deterministic
        for name in sorted(files):
            if name.lower().endswith(".mrxs"):
                return os.path.join(root, name)
    raise gr.Error("The uploaded zip archive does not contain a `.mrxs` file.")


def _read_thumbnail(path, size=(512, 512)):
    """Open the slide at *path* with OpenSlide, return its thumbnail, and close it."""
    slide = openslide.OpenSlide(path)
    try:
        return slide.get_thumbnail(size)
    finally:
        slide.close()


def process_slide(slide_path):
    """Segment the thumbnail of an uploaded whole-slide image.

    Args:
        slide_path: Path of the uploaded file, or, on Gradio 3.x, the
            uploaded file object (whose ``.name`` is the path). It may be any
            slide OpenSlide can open, or a ``.zip`` archive containing a
            ``.mrxs`` slide. ``None`` (nothing uploaded) yields ``(None, None)``.

    Returns:
        Tuple ``(image, mask)``:
        - ``image`` is the slide thumbnail (fitting within 512 x 512) after
          ``expand2square(..., "white")``, which presumably pads it to a square.
        - ``mask`` is ``process_image`` applied to that image.

    Raises:
        gr.Error: if a zip upload contains no ``.mrxs`` file.
    """
    if slide_path is None:
        return None, None
    if not isinstance(slide_path, (str, os.PathLike)):
        # Gradio 3.x hands over a tempfile wrapper; its .name is the full path.
        slide_path = slide_path.name
    slide_path = os.fspath(slide_path)

    if slide_path.lower().endswith(".zip"):
        # .mrxs slides are uploaded as a zip (the .mrxs file is accompanied by
        # its data directory). Extract into a fresh per-request directory so
        # files left over from earlier or concurrent uploads can never be
        # picked up. The directory is deleted once the thumbnail (an in-memory
        # PIL image) has been read and the slide handle closed.
        with tempfile.TemporaryDirectory(prefix="cache_mrxs_") as extract_dir:
            # Explicit format: shutil's extension detection is case-sensitive.
            shutil.unpack_archive(slide_path, extract_dir, format="zip")
            thumbnail = _read_thumbnail(_find_mrxs(extract_dir))
    else:
        thumbnail = _read_thumbnail(slide_path)

    image = expand2square(thumbnail, "white")
    return image, process_image(np.array(image))
# "Slide-to-Mask" tab: upload a whole-slide file (or a zip for .mrxs slides)
# and show both the thumbnail that was segmented and the predicted mask.
_slide_upload = gr.File(
    label="Input slide file (input zip for `.mrxs` files)",
)
_slide_thumbnail_view = gr.Image(
    label="Input Image",
    image_mode="RGB",
    height=400,
    width=400,
)
_slide_mask_view = gr.Image(
    label="Model Prediction",
    image_mode="L",
    height=400,
    width=400,
)
interface_slide = gr.Interface(
    fn=process_slide,
    inputs=[_slide_upload],
    outputs=[_slide_thumbnail_view, _slide_mask_view],
)

# Combine both interfaces into one tabbed app and start the web server.
_TAB_NAMES = ["Image-to-Mask", "Slide-to-Mask"]
demo = gr.TabbedInterface([interface_image, interface_slide], _TAB_NAMES)
demo.launch()