Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import codecs | |
| import torch | |
| import torchvision.transforms as transforms | |
| import gradio as gr | |
| from PIL import Image | |
| from unetplusplus import NestedUNet | |
# Seed for reproducibility; make cuDNN deterministic when a GPU is present.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# Device: inference is pinned to CPU (the checkpoint is also mapped to CPU below).
DEVICE = "cpu"
print(DEVICE)

# Color map indexed by predicted class id to produce the output visualization.
cmap = np.load("cmap.npy")

MODEL_DIR = "./models"
MODEL_PATH = os.path.join(MODEL_DIR, "masksupnyu39.31d.pth")
MODEL_URL = (
    "https://github.com/hasibzunair/masksup-segmentation/releases/download/"
    "v0.1/masksupnyu39.31iou.pth"
)

# Create the weights directory in-process; exist_ok makes reruns harmless
# (the previous `mkdir` shell call errored when the directory existed).
os.makedirs(MODEL_DIR, exist_ok=True)

# Download the weights once. download_url_to_file streams into a temporary
# file and moves it into place only after a complete download, so a failed
# download raises instead of leaving an empty/truncated file behind (which
# `wget -O` does, causing every later start to skip the download and then
# crash in torch.load).
if not os.path.exists(MODEL_PATH):
    torch.hub.download_url_to_file(MODEL_URL, MODEL_PATH)

# Build the network (40 output classes; must match the checkpoint) and load
# the saved state dict onto CPU before moving it to DEVICE.
model = NestedUNet(num_classes=40)
checkpoint = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint)
model = model.to(DEVICE)
model.eval()
# Preprocessing applied to every input image; built once rather than per request.
# (CenterCrop to the same size as the Resize is a no-op, kept for parity.)
_TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


# Main inference function
def inference(img_path):
    """Segment an image and return its color-coded class map.

    Args:
        img_path: Path to the input image, as supplied by the Gradio image
            input configured with ``type="filepath"``; may be ``None`` when
            the user submits without an image.

    Returns:
        ``cmap[class_map]``, where ``class_map`` is the per-pixel argmax over
        the model output's channel dimension (dim 1) for the 224x224 resized
        input, or ``None`` if no image was provided.
    """
    if img_path is None:
        return None

    # Use a context manager so the underlying file handle is closed promptly;
    # convert() loads the pixels and returns an independent image.
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    batch = _TRANSFORM(image)[None, :]  # add batch dimension

    with torch.no_grad():
        logits = model(batch.to(DEVICE).float())

    # Take the argmax on the raw logits. The previous sigmoid -> softmax chain
    # cannot change the winner except by harm: float32 sigmoid saturates to
    # exactly 1.0 for large logits, creating spurious ties that argmax then
    # resolves to the lowest class index.
    class_map = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    return cmap[class_map]
# App
title = "Masked Supervised Learning for Semantic Segmentation"

# Read the HTML description with a context manager so the file is closed.
with open("description.html", "r", encoding="utf-8") as fh:
    description = fh.read()

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"

# gr.inputs / gr.outputs and launch(enable_queue=...) were removed from
# Gradio; use the top-level components and Interface.queue() instead.
# allow_flagging takes a string mode, not a bool.
gr.Interface(
    inference,
    gr.Image(type="filepath", label="Input Image"),
    gr.Image(type="numpy", label="Predicted Output"),
    examples=[
        "./sample_images/a.png",
        "./sample_images/b.png",
        "./sample_images/c.png",
        "./sample_images/d.png",
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging="never",
    analytics_enabled=False,
).queue().launch(debug=True)