model-inference / app.py
osbm's picture
Update app.py
4cc4c48 verified
import os
import shutil
import tempfile

import albumentations as A
import gradio as gr
import monai
import numpy as np
import openslide
import torch
from albumentations.pytorch import ToTensorV2
from monai.networks.nets import UNet
from PIL import Image

from project_utils.preprocessing import expand2square
# Checkpoint is resolved relative to this script rather than the current
# working directory, so the app starts correctly no matter where it is
# launched from.
_CHECKPOINT_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "best_model.pth"
)

# 2D U-Net: 3 input channels (RGB), 1 output channel. Its raw output is
# passed through a sigmoid by the inference code (see process_image).
# The architecture hyper-parameters must match those used when the
# checkpoint was trained, otherwise load_state_dict will fail.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
# Load weights onto the CPU so the app also runs on machines without a GPU.
model.load_state_dict(
    torch.load(_CHECKPOINT_PATH, map_location=torch.device("cpu"))
)
# Inference only: disables dropout.
model.eval()
# Preprocessing pipeline: resize to the 512x512 network input, then convert
# the HWC numpy array into a CHW torch tensor. Built once at import time
# instead of on every request.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])


def process_image(image):
    """Run the segmentation model on one image and return the predicted mask.

    Args:
        image: ``H x W x 3`` array with values in ``[0, 255]``, as produced by
            ``gr.Image(type="numpy", image_mode="RGB")``. Grayscale
            (``H x W`` or ``H x W x 1``) arrays are replicated to 3 channels,
            and a 4th (alpha) channel is dropped, because the model expects
            exactly 3 input channels. ``None`` (no image submitted) is allowed.

    Returns:
        A ``512 x 512`` float32 array of per-pixel sigmoid outputs in
        ``[0, 1]``, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the user submits without providing an image.
        return None

    image = np.asarray(image)
    # Normalize the channel count to the 3 channels the model was built with.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    # Scale [0, 255] pixel values to [0, 1] floats.
    image = image / 255.0
    image = image.astype(np.float32)

    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    batch = _INFERENCE_TRANSFORMS(image=image)["image"].unsqueeze(0)
    with torch.no_grad():
        mask_pred = torch.sigmoid(model(batch))
    # Drop the batch and channel dimensions -> (512, 512).
    return mask_pred[0, 0, :, :].numpy()
# Tab 1: upload an RGB image and get back the predicted single-channel
# ("L" mode) mask produced by process_image.
interface_image = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
            height=400,
            width=400,
        )
    ],
)
def _read_thumbnail(path):
    """Open a slide with OpenSlide, return a thumbnail fitting in 512x512, then close it."""
    slide = openslide.OpenSlide(path)
    try:
        return slide.get_thumbnail((512, 512))
    finally:
        # Release the underlying file handles even if reading fails.
        slide.close()


def _find_mrxs(root):
    """Return the first ``.mrxs`` file under *root*, searched recursively in sorted order."""
    for dirpath, dirnames, filenames in os.walk(root):
        # Sort in place so the traversal order (and thus the result) is deterministic.
        dirnames.sort()
        for filename in sorted(filenames):
            if filename.lower().endswith(".mrxs"):
                return os.path.join(dirpath, filename)
    raise gr.Error("No `.mrxs` file found in the uploaded zip archive.")


def process_slide(slide_path):
    """Segment a whole-slide image using a low-resolution thumbnail.

    Args:
        slide_path: Path of the uploaded file. Newer Gradio versions pass a
            path string; older ones pass a tempfile object exposing ``.name``.
            Both are accepted. The file is either a slide that OpenSlide can
            open directly, or a ``.zip`` archive containing a ``.mrxs`` slide
            together with its companion files.

    Returns:
        ``(image, mask)``: ``image`` is the slide thumbnail padded to a
        square with white by ``expand2square``. ``mask`` is the result of
        ``process_image`` on that image.

    Raises:
        gr.Error: If no file was uploaded or the archive has no ``.mrxs`` file.
    """
    if slide_path is None:
        raise gr.Error("Please upload a slide file.")
    # Older Gradio versions hand over a tempfile wrapper instead of a string.
    slide_path = getattr(slide_path, "name", slide_path)

    if not slide_path.lower().endswith(".zip"):
        thumbnail = _read_thumbnail(slide_path)
    else:
        # mrxs slides consist of several files, so they are uploaded as a zip.
        # Each upload gets its own temporary directory, so repeated or
        # concurrent uploads never see each other's files, and the extracted
        # data is deleted afterwards.
        with tempfile.TemporaryDirectory() as extract_dir:
            # Explicit format: shutil's extension guess is case-sensitive.
            shutil.unpack_archive(slide_path, extract_dir, format="zip")
            thumbnail = _read_thumbnail(_find_mrxs(extract_dir))

    image = expand2square(thumbnail, "white")
    return image, process_image(np.array(image))
# Tab 2: upload a slide file (or a zip holding an .mrxs slide) and get back
# the square RGB thumbnail alongside the predicted single-channel mask.
# Components are created in the same order as they appear in the UI.
_slide_upload = gr.File(label="Input slide file (input zip for `.mrxs` files)")
_slide_preview = gr.Image(
    label="Input Image",
    image_mode="RGB",
    width=400,
    height=400,
)
_slide_mask = gr.Image(
    label="Model Prediction",
    image_mode="L",
    width=400,
    height=400,
)
interface_slide = gr.Interface(
    fn=process_slide,
    inputs=[_slide_upload],
    outputs=[_slide_preview, _slide_mask],
)
# Combine both workflows into one app with two tabs.
demo = gr.TabbedInterface(
    [interface_image, interface_slide],
    ["Image-to-Mask", "Slide-to-Mask"],
)

# Start the server only when run as a script. Importing this module
# (e.g. from tests or other tooling) then does not block on launch().
if __name__ == "__main__":
    demo.launch()