# hasibzunair · commit 92bdcd3 · "change arg type"
import os
import numpy as np
import codecs
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
from unetplusplus import NestedUNet
# Reproducibility: fix torch's global RNG seed and, only when CUDA is
# available, request deterministic cuDNN kernels.
torch.manual_seed(0)
_cuda_present = torch.cuda.is_available()
if _cuda_present:
    torch.backends.cudnn.deterministic = True
# Device the model and its inputs are moved to. Pinned to the CPU;
# CUDA availability is not consulted here.
DEVICE: str = "cpu"
print(DEVICE)
# Color lookup table. `inference` indexes it with predicted class ids to turn
# a label map into an image (presumably one RGB row per class; confirm
# against cmap.npy).
cmap = np.load(file="cmap.npy")
# Make directories. makedirs(exist_ok=True) is a no-op when ./models already
# exists, unlike shelling out to `mkdir`, which errors on every restart.
os.makedirs("./models", exist_ok=True)
# Get model weights on first start-up. download_url_to_file writes to a
# temporary file and moves it into place only after the transfer completes,
# so a failed download raises here. `wget -O` instead left a truncated or
# empty .pth behind, which the exists() check would then skip on every later
# run.
if not os.path.exists("./models/masksupnyu39.31d.pth"):
    torch.hub.download_url_to_file(
        "https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth",
        "./models/masksupnyu39.31d.pth",
    )
# Load model: UNet++ with 40 output classes. Weights are deserialized onto the
# CPU, then the network is moved to DEVICE and put in evaluation mode.
model = NestedUNet(num_classes=40)
# NOTE(review): torch.load unpickles the downloaded checkpoint; consider
# weights_only=True if the installed torch version supports it.
checkpoint = torch.load(
    "./models/masksupnyu39.31d.pth",
    map_location=torch.device("cpu"),
)
model.load_state_dict(checkpoint)
# nn.Module.to() and .eval() both return the module itself.
model = model.to(DEVICE).eval()
# Main inference function
# Preprocessing is the same for every request, so the pipeline is built once.
# Resize((224, 224)) already yields exactly 224x224, so no crop is needed.
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def inference(img_path):
    """Segment one image and return its color-coded label map.

    Args:
        img_path: Path to the input image. Gradio passes a file path here,
            or None when no image was provided.

    Returns:
        ``cmap`` indexed by the per-pixel predicted class ids of the image
        resized to 224x224, or None when ``img_path`` is None.
    """
    if img_path is None:
        return None
    # Context manager releases the file handle. convert() returns a new,
    # fully loaded image, so it stays valid after the `with` exits.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    batch = _TRANSFORM(image).unsqueeze(0)  # add batch dimension
    # Predict
    with torch.no_grad():
        logits = model(batch.to(DEVICE).float())
    # Per-pixel class id. sigmoid and softmax are monotonic, so argmax over
    # them would pick the same class as argmax over the raw logits. The
    # exception is float32 sigmoid rounding large logits (above ~17) to
    # exactly 1.0: the resulting ties resolve to the lowest index. Taking
    # argmax over the logits avoids that.
    labels = logits.argmax(dim=1)[0].cpu().numpy()
    return cmap[labels]
# App
title = "Masked Supervised Learning for Semantic Segmentation"
# HTML shown under the title. The `with` block closes the file once it has
# been read instead of keeping the handle open for the life of the process.
with open("description.html", "r", encoding="utf-8") as description_file:
    description = description_file.read()
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"
# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are
# deprecated Gradio 3 APIs (removed in Gradio 4); migrate to gr.Image(...)
# and .queue() when upgrading Gradio.
gr.Interface(
    inference,
    gr.inputs.Image(type="filepath", label="Input Image"),
    gr.outputs.Image(type="numpy", label="Predicted Output"),
    examples=[
        "./sample_images/a.png",
        "./sample_images/b.png",
        "./sample_images/c.png",
        "./sample_images/d.png",
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
).launch(debug=True, enable_queue=True)