|
import os |
|
os.system("pip install git+https://github.com/zhanghang1989/PyTorch-Encoding/") |
|
os.system("pip install git+https://github.com/openai/CLIP.git") |
|
import torch |
|
import argparse |
|
import numpy as np |
|
from tqdm import tqdm |
|
from collections import OrderedDict |
|
import torch.nn.functional as F |
|
from torch.utils import data |
|
import torchvision.transforms as transform |
|
from torch.nn.parallel.scatter_gather import gather |
|
from additional_utils.models import LSeg_MultiEvalModule |
|
from modules.lseg_module import LSegModule |
|
import cv2 |
|
import math |
|
import types |
|
import functools |
|
import torchvision.transforms as torch_transforms |
|
import copy |
|
import itertools |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import clip |
|
from encoding.models.sseg import BaseNet |
|
import matplotlib as mpl |
|
import matplotlib.colors as mplc |
|
import matplotlib.figure as mplfigure |
|
import matplotlib.patches as mpatches |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
from data import get_dataset |
|
import torchvision.transforms as transforms |
|
|
|
import gradio as gr |
|
|
|
# Module-level settings.
# NOTE(review): `model_name` is not read anywhere in the visible part of this
# file; confirm whether it is still needed.
model_name = "convnext_xlarge_in22k"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
def get_new_pallete(num_cls):
    """Return a flat ``[r0, g0, b0, r1, g1, b1, ...]`` palette for ``num_cls`` classes.

    Each class id is spread over the colour channels three bits at a time,
    starting from the most significant bit of each channel. This is the same
    bit-interleaving scheme as the PASCAL VOC colour map. Class 0 is black.

    Args:
        num_cls: Number of classes; the result has ``3 * num_cls`` entries.

    Returns:
        A list of ints in ``[0, 255]`` suitable for ``PIL.Image.putpalette``.
    """
    palette = []
    for class_id in range(num_cls):
        red = green = blue = 0
        bit_pos = 7
        code = class_id
        while code > 0:
            red |= (code & 1) << bit_pos
            green |= ((code >> 1) & 1) << bit_pos
            blue |= ((code >> 2) & 1) << bit_pos
            bit_pos -= 1
            code >>= 3
        palette.extend((red, green, blue))
    return palette
|
|
|
def get_new_mask_pallete(npimg, new_palette, out_label_flag=False, labels=None):
    """Colourise a label mask with a palette and optionally build legend patches.

    Args:
        npimg: Array of per-pixel class indices. Singleton dimensions are
            squeezed away and values are cast to ``uint8`` before the image
            is built.
        new_palette: Flat ``[r0, g0, b0, r1, g1, b1, ...]`` list of 0-255
            ints, e.g. as returned by ``get_new_pallete``.
        out_label_flag: When true, also build one legend patch per class
            index that occurs in ``npimg``.
        labels: Class names indexed by class id; required when
            ``out_label_flag`` is true.

    Returns:
        ``(out_img, patches)``. ``out_img`` is a PIL image with ``new_palette``
        attached. ``patches`` is a list of ``matplotlib.patches.Patch`` legend
        entries, and is empty unless ``out_label_flag`` is set.

    Raises:
        ValueError: if ``out_label_flag`` is true but ``labels`` is None.
    """
    out_img = Image.fromarray(npimg.squeeze().astype('uint8'))
    out_img.putpalette(new_palette)

    # Bound up front so the no-legend path returns an empty list; previously
    # it raised UnboundLocalError at the return statement.
    patches = []
    if out_label_flag:
        if labels is None:
            raise ValueError("labels must be provided when out_label_flag is True")
        for index in np.unique(npimg):
            # Palette entries are 0-255; matplotlib expects colours in [0, 1].
            cur_color = [new_palette[index * 3 + channel] / 255.0 for channel in range(3)]
            patches.append(mpatches.Patch(color=cur_color, label=labels[index]))

    return out_img, patches
|
|
|
def _build_arg_parser():
    """Build the argparse parser that supplies LSeg's default options.

    ``load_model`` parses an empty argument list, so in this demo only the
    defaults declared here matter. Several of them are overridden afterwards.
    """
    parser = argparse.ArgumentParser(description="PyTorch Segmentation")
    parser.add_argument(
        "--model", type=str, default="encnet", help="model name (default: encnet)"
    )
    parser.add_argument(
        "--backbone",
        type=str,
        default="clip_vitl16_384",
        help="backbone name (default: clip_vitl16_384)",
    )
    parser.add_argument(
        "--dataset", type=str, default="ade20k", help="dataset name (default: ade20k)"
    )
    parser.add_argument(
        "--workers", type=int, default=16, metavar="N", help="dataloader threads"
    )
    parser.add_argument("--base-size", type=int, default=520, help="base image size")
    parser.add_argument("--crop-size", type=int, default=480, help="crop image size")
    parser.add_argument(
        "--train-split",
        type=str,
        default="train",
        help="dataset train split (default: train)",
    )
    parser.add_argument(
        "--aux", action="store_true", default=False, help="Auxiliary Loss"
    )
    parser.add_argument(
        "--se-loss",
        action="store_true",
        default=False,
        help="Semantic Encoding Loss SE-loss",
    )
    parser.add_argument(
        "--se-weight", type=float, default=0.2, help="SE-loss weight (default: 0.2)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
        metavar="N",
        help="input batch size for training (default: 16)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=16,
        metavar="N",
        help="input batch size for testing (default: 16)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
    )
    parser.add_argument("--weights", type=str, default="", help="checkpoint to test")
    parser.add_argument(
        "--eval", action="store_true", default=False, help="evaluating mIoU"
    )
    parser.add_argument(
        "--export",
        type=str,
        default=None,
        help="put the path to resuming file if needed",
    )
    parser.add_argument(
        "--acc-bn",
        action="store_true",
        default=False,
        help="Re-accumulate BN statistics",
    )
    parser.add_argument(
        "--test-val",
        action="store_true",
        default=False,
        help="generate masks on val set",
    )
    parser.add_argument(
        "--no-val",
        action="store_true",
        default=False,
        help="skip validation during training",
    )
    parser.add_argument("--module", default="lseg", help="select model definition")
    parser.add_argument(
        "--data-path",
        type=str,
        default="../datasets/",
        help="path to test image folder",
    )
    parser.add_argument(
        "--no-scaleinv",
        dest="scale_inv",
        default=True,
        action="store_false",
        help="turn off scaleinv layers",
    )
    parser.add_argument(
        "--widehead", default=False, action="store_true", help="wider output head"
    )
    parser.add_argument(
        "--widehead_hr", default=False, action="store_true", help="wider output head"
    )
    parser.add_argument(
        "--ignore_index",
        type=int,
        default=-1,
        help="numeric value of ignore label in gt",
    )
    parser.add_argument(
        "--label_src", type=str, default="default", help="how to get the labels"
    )
    parser.add_argument(
        "--arch_option",
        type=int,
        default=0,
        help="which kind of architecture to be used",
    )
    parser.add_argument(
        "--block_depth", type=int, default=0, help="how many blocks should be used"
    )
    parser.add_argument(
        "--activation",
        choices=["lrelu", "tanh"],
        default="lrelu",
        help="use which activation to activate the block",
    )
    return parser


@functools.lru_cache(maxsize=1)
def load_model():
    """Load the LSeg demo checkpoint and wrap it for multi-scale inference.

    The result is cached with ``functools.lru_cache``, so repeated calls reuse
    the loaded model. This replaces ``@st.cache``, which failed with NameError
    because ``streamlit`` is never imported in this file.

    Returns:
        A tuple ``(evaluator, transform)``:

        * ``evaluator`` is an ``LSeg_MultiEvalModule`` in eval mode, moved to
          CUDA, that runs multiple scales with horizontal flipping.
        * ``transform`` is the torchvision pipeline (ToTensor, Normalize with
          mean/std 0.5, Resize to [360, 480]) to apply to an RGB numpy image
          before inference.
    """
    args = _build_arg_parser().parse_args(args=[])
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)

    torch.manual_seed(args.seed)

    # Demo-specific overrides of the parser defaults.
    args.scale_inv = False
    args.widehead = True
    args.dataset = "ade20k"
    args.backbone = "clip_vitl16_384"
    args.weights = "checkpoints/demo_e200.ckpt"
    args.ignore_index = 255

    module = LSegModule.load_from_checkpoint(
        checkpoint_path=args.weights,
        # Map checkpoint tensors onto the CPU. This keyword was previously
        # misspelled `map_locatin`, so Lightning never received it as its
        # `map_location` argument.
        map_location="cpu",
        data_path=args.data_path,
        dataset=args.dataset,
        backbone=args.backbone,
        aux=args.aux,
        num_features=256,
        aux_weight=0,
        se_loss=False,
        se_weight=0,
        base_lr=0,
        batch_size=1,
        max_epochs=0,
        ignore_index=args.ignore_index,
        dropout=0.0,
        scale_inv=args.scale_inv,
        augment=False,
        no_batchnorm=False,
        widehead=args.widehead,
        widehead_hr=args.widehead_hr,
        arch_option=0,
        block_depth=0,
        activation="lrelu",
    )

    model = module.net if isinstance(module.net, BaseNet) else module
    model = model.eval().cpu()

    scales = (
        [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25]
        if args.dataset == "citys"
        else [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    )

    # Matches the Normalize(mean=0.5, std=0.5) used by the transform below.
    model.mean = [0.5, 0.5, 0.5]
    model.std = [0.5, 0.5, 0.5]

    # NOTE(review): `.cuda()` is unconditional, so this demo requires a GPU
    # even though the module-level `device` allows a CPU fallback.
    evaluator = LSeg_MultiEvalModule(model, scales=scales, flip=True).cuda()
    evaluator.eval()

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize([360, 480]),
        ]
    )

    return evaluator, transform
|
|
|
""" |
|
# LSeg Demo |
|
""" |
|
lseg_model, lseg_transform = load_model() |
|
|
|
|
|
uploaded_file = gr.inputs.Image(type='pil') |
|
input_labels = st.text_input("Input labels", value="dog, grass, other") |
|
gr.outputs.Label(type="confidences",num_top_classes=5) |
|
st.write("The labels are", input_labels) |
|
|
|
image = Image.open(uploaded_file) |
|
pimage = lseg_transform(np.array(image)).unsqueeze(0) |
|
|
|
labels = [] |
|
for label in input_labels.split(","): |
|
labels.append(label.strip()) |
|
|
|
with torch.no_grad(): |
|
outputs = lseg_model.parallel_forward(pimage, labels) |
|
|
|
predicts = [ |
|
torch.max(output, 1)[1].cpu().numpy() |
|
for output in outputs |
|
] |
|
|
|
image = pimage[0].permute(1,2,0) |
|
image = image * 0.5 + 0.5 |
|
image = Image.fromarray(np.uint8(255*image)).convert("RGBA") |
|
|
|
pred = predicts[0] |
|
new_palette = get_new_pallete(len(labels)) |
|
mask, patches = get_new_mask_pallete(pred, new_palette, out_label_flag=True, labels=labels) |
|
seg = mask.convert("RGBA") |
|
|
|
fig = plt.figure() |
|
plt.subplot(121) |
|
plt.imshow(image) |
|
plt.axis('off') |
|
|
|
plt.subplot(122) |
|
plt.imshow(seg) |
|
plt.legend(handles=patches, loc='upper right', bbox_to_anchor=(1.3, 1), prop={'size': 5}) |
|
plt.axis('off') |
|
|
|
plt.tight_layout() |
|
|
|
|
|
st.pyplot(fig) |
|
|
|
title = "LSeg" |
|
|
|
description = "Gradio demo for LSeg for semantic segmentation. To use it, simply upload your image, or click one of the examples to load them, then add any label set" |
|
|
|
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.03546' target='_blank'>Language-driven Semantic Segmentation</a> | <a href='hhttps://github.com/isl-org/lang-seg' target='_blank'>Github Repo</a></p>" |
|
|
|
examples = ['test.jpeg'] |
|
|
|
gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, analytics_enabled=False, examples=examples).launch(enable_queue=True) |