|
import os |
|
import numpy as np |
|
import codecs |
|
import torch |
|
import torchvision.transforms as transforms |
|
import gradio as gr |
|
|
|
from PIL import Image |
|
|
|
from unetplusplus import NestedUNet |
|
|
|
# Seed PyTorch's RNG so repeated launches of the demo start from the same state.
torch.manual_seed(0)

# Ask cuDNN for deterministic kernels, but only when a CUDA runtime is visible.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# The demo always runs on the CPU, whatever hardware is present.
DEVICE = "cpu"
print(DEVICE)

# Colour lookup table; `inference` indexes it with the predicted class ids.
cmap = np.load("cmap.npy")
|
|
|
|
|
# Create the checkpoint directory without shelling out; `exist_ok=True` makes
# this a no-op on restarts instead of a failing `mkdir`.
os.makedirs("./models", exist_ok=True)

# Fetch the pretrained weights only when they are not already on disk.
if not os.path.exists("./models/masksupnyu39.31d.pth"):
    # `torch.hub.download_url_to_file` raises on network/HTTP errors and only
    # moves the file to its final path once the download has completed, so a
    # failed attempt does not leave an empty file behind that would satisfy
    # the `os.path.exists` check on the next start (as `wget -O` did).
    torch.hub.download_url_to_file(
        "https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth",
        "./models/masksupnyu39.31d.pth",
    )
|
|
|
|
|
# Instantiate the 40-class network and restore the downloaded weights.
model = NestedUNet(num_classes=40)

# Tensors in the checkpoint are mapped onto the CPU while loading.
# NOTE(review): torch.load unpickles the file; the checkpoint comes from the
# project's own GitHub release, so it is treated as trusted here.
checkpoint = torch.load("./models/masksupnyu39.31d.pth", map_location="cpu")
model.load_state_dict(checkpoint)

# Move to the target device and switch to evaluation mode (`eval()` returns
# the module itself, so the calls can be chained).
model = model.to(DEVICE).eval()
|
|
|
|
|
|
|
# Preprocessing pipeline, built once at import time rather than on every
# request: resize to 224x224, (no-op) centre crop, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def inference(img_path):
    """Segment the image at ``img_path`` and return its colour-coded label map.

    Args:
        img_path: Path of the input image file (anything ``PIL.Image.open``
            accepts).

    Returns:
        The result of indexing the module-level ``cmap`` table with the
        per-pixel predicted class ids (a NumPy array).
    """
    # Close the underlying file deterministically; `convert` returns a fully
    # loaded copy, so the image remains usable after the `with` block.
    with Image.open(img_path) as img_file:
        image = img_file.convert("RGB")

    # Add the leading batch dimension expected by the network.
    batch = _PREPROCESS(image).unsqueeze(0).to(DEVICE).float()

    with torch.no_grad():
        logits = model(batch)

    # Sigmoid and softmax are both order-preserving along the class axis, so
    # the arg-max of the raw logits selects the same class. Skipping them also
    # avoids spurious ties when float32 sigmoid outputs saturate at 1.0.
    labels = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

    # Map every class id to its colour.
    return cmap[labels]
|
|
|
|
|
|
|
# Static text shown in the Gradio page.
title = "Masked Supervised Learning for Semantic Segmentation"

# Read the HTML description through a context manager so the file handle is
# closed immediately instead of being left for the garbage collector. The same
# codecs reader is used, so the decoded text is unchanged.
with codecs.open("description.html", "r", "utf-8") as description_file:
    description = description_file.read()

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"
|
|
|
# Input/output widgets: the handler receives a file path and returns a NumPy
# image.
# NOTE(review): `gr.inputs` / `gr.outputs` and `enable_queue` belong to the
# older Gradio API; confirm against the pinned Gradio version before upgrading.
input_widget = gr.inputs.Image(type="filepath", label="Input Image")
output_widget = gr.outputs.Image(type="numpy", label="Predicted Output")

# Sample images offered as one-click examples in the UI.
example_images = [
    "./sample_images/a.png",
    "./sample_images/b.png",
    "./sample_images/c.png",
    "./sample_images/d.png",
]

# Assemble the demo page around `inference`, with flagging and analytics off.
demo = gr.Interface(
    inference,
    input_widget,
    output_widget,
    examples=example_images,
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
)

# Start the web server with debug output and request queuing enabled.
demo.launch(debug=True, enable_queue=True)
|
|