|
import torch |
|
# Pretrained FCN-ResNet101 fetched through torch.hub from the torchvision
# repository pinned at tag v0.9.0. nn.Module.eval() returns the module itself,
# so loading and switching to inference mode can be chained.
model = torch.hub.load(
    "pytorch/vision:v0.9.0",
    "fcn_resnet101",
    pretrained=True,
).eval()
|
|
|
# Import the submodule explicitly: a plain `import urllib` does not guarantee
# that `urllib.request` is loaded. (The removed Python 2 `urllib.URLopener`
# path always failed on Python 3 and was only rescued by a bare except.)
import urllib.request

# Download the example image referenced by the Interface's `examples` list.
# Any network or HTTP error propagates instead of being silently masked.
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)
|
|
|
from PIL import Image |
|
from torchvision import transforms |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def inference(input_image):
    """Segment ``input_image`` with the FCN model and plot a colorized class map.

    Args:
        input_image: A PIL image (the Gradio input is declared with
            ``type="pil"``). Any PIL mode is accepted; it is converted to RGB.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the
        colorized segmentation mask; the Gradio ``type="plot"`` output
        renders it.
    """
    # Force three channels: the Normalize step below has exactly three
    # mean/std entries, so RGBA, grayscale or palette uploads would fail.
    input_image = input_image.convert("RGB")

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension for the model.
    input_batch = preprocess(input_image).unsqueeze(0)

    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        # First (only) element of the batch; dim 0 indexes classes.
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)

    # One pseudo-random RGB colour per class. The class count is read from the
    # model output instead of being hard-coded.
    num_classes = output.shape[0]
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.arange(num_classes)[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # Pixel values are class ids: resize with NEAREST so no interpolated
    # (non-existent) labels are produced.
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(
        input_image.size, resample=Image.NEAREST
    )
    r.putpalette(colors)

    # Clear pyplot's shared current figure first; otherwise every request's
    # image stays attached to the same axes and memory grows per call.
    plt.clf()
    plt.imshow(r)
    return plt
|
|
|
|
|
# Text shown by the Gradio interface: page title, the blurb under it, and an
# HTML footer linking the paper and the torchvision model source.
title = "FCN-RESNET101"

description = (
    "Gradio demo for FCN-RESNET101, Fully-Convolutional Network model with a "
    "ResNet-101 backbone. To use it, simply upload your image, or click one of "
    "the examples to load them. Read more at the links below."
)

article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1605.06211'>Fully Convolutional Networks for Semantic Segmentation</a>"
    " | "
    "<a href='https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/fcn.py'>Github Repo</a>"
    "</p>"
)
|
|
|
# Build the web UI around `inference`: a PIL image goes in, a matplotlib plot
# comes out, and the local "dog.jpg" is offered as a clickable example.
demo = gr.Interface(
    fn=inference,
    inputs=gr.inputs.Image(type="pil", label="Input"),
    outputs=gr.outputs.Image(type="plot", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[["dog.jpg"]],
)
demo.launch()