ov-seg / open_vocab_seg /utils /post_process_utils.py
liangfeng
add ovseg
583456e
raw history blame
No virus
1.99 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import torch
from torch.nn import functional as F
import numpy as np
try:
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import (
unary_from_softmax,
unary_from_labels,
create_pairwise_bilateral,
create_pairwise_gaussian,
)
except:
dcrf = None
def dense_crf_post_process(
logits,
image,
n_labels=None,
max_iters=5,
pos_xy_std=(3, 3),
pos_w=3,
bi_xy_std=(80, 80),
bi_rgb_std=(13, 13, 13),
bi_w=10,
):
"""
logits : [C,H,W]
image : [3,H,W]
"""
if dcrf is None:
raise FileNotFoundError(
"pydensecrf is required to perform dense crf inference."
)
if isinstance(logits, torch.Tensor):
logits = F.softmax(logits, dim=0).detach().cpu().numpy()
U = unary_from_softmax(logits)
n_labels = logits.shape[0]
elif logits.ndim == 3:
U = unary_from_softmax(logits)
n_labels = logits.shape[0]
else:
assert n_labels is not None
U = unary_from_labels(logits, n_labels, zero_unsure=False)
d = dcrf.DenseCRF2D(image.shape[1], image.shape[0], n_labels)
d.setUnaryEnergy(U)
# This adds the color-independent term, features are the locations only.
d.addPairwiseGaussian(
sxy=pos_xy_std,
compat=pos_w,
kernel=dcrf.DIAG_KERNEL,
normalization=dcrf.NORMALIZE_SYMMETRIC,
)
# This adds the color-dependent term, i.e. features are (x,y,r,g,b).
d.addPairwiseBilateral(
sxy=bi_xy_std,
srgb=bi_rgb_std,
rgbim=image,
compat=bi_w,
kernel=dcrf.DIAG_KERNEL,
normalization=dcrf.NORMALIZE_SYMMETRIC,
)
# Run five inference steps.
logits = d.inference(max_iters)
logits = np.asarray(logits).reshape((n_labels, image.shape[0], image.shape[1]))
return torch.from_numpy(logits)