IDM-VTON
update IDM-VTON Demo
938e515
raw
history blame
No virus
10.3 kB
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np
import torch
from detectron2.layers import ShapeSpec, cat, interpolate
from detectron2.modeling import ROI_HEADS_REGISTRY, StandardROIHeads
from detectron2.modeling.roi_heads.mask_head import (
build_mask_head,
mask_rcnn_inference,
mask_rcnn_loss,
)
from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals
from .point_features import (
generate_regular_grid_point_coords,
get_uncertain_point_coords_on_grid,
get_uncertain_point_coords_with_randomness,
point_sample,
point_sample_fine_grained_features,
)
from .point_head import build_point_head, roi_mask_point_loss
def calculate_uncertainty(logits, classes):
    """
    Score how uncertain each mask logit is: the closer a logit is to 0.0, the higher its score.

    The score is the negated absolute value of the logit that belongs to the class given in
    `classes` (or of the single channel when the prediction is class-agnostic).

    Args:
        logits (Tensor): A tensor of shape (R, C, ...) for class-specific predictions or
            (R, 1, ...) for class-agnostic predictions, where R is the total number of
            predicted masks in all images and C is the number of foreground classes.
            The values are logits.
        classes (list): A list of length R that contains either the predicted or the ground
            truth class for each predicted mask. Ignored in the class-agnostic case.

    Returns:
        scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores, with
            the most uncertain locations having the highest score.
    """
    class_agnostic = logits.shape[1] == 1
    if class_agnostic:
        # Only one channel exists, so there is nothing to select.
        selected_logits = logits.clone()
    else:
        # Pick, for every mask r, the channel classes[r]; unsqueeze restores the channel axis.
        mask_index = torch.arange(logits.shape[0], device=logits.device)
        selected_logits = logits[mask_index, classes].unsqueeze(1)
    return -selected_logits.abs()
@ROI_HEADS_REGISTRY.register()
class PointRendROIHeads(StandardROIHeads):
    """
    The RoI heads class for PointRend instance segmentation models.

    In this class we redefine the mask head of `StandardROIHeads` leaving all other heads intact.
    To avoid namespace conflict with other heads we use names starting from `mask_` for all
    variables that correspond to the mask head in the class's namespace.
    """

    def __init__(self, cfg, input_shape):
        # TODO use explicit args style
        super().__init__(cfg, input_shape)
        self._init_mask_head(cfg, input_shape)

    def _init_mask_head(self, cfg, input_shape):
        """
        Read the mask-head settings from `cfg` and build the coarse mask head.

        Does nothing beyond setting `self.mask_on` when `cfg.MODEL.MASK_ON` is false.
        Otherwise it also delegates to `_init_point_head`.

        Args:
            cfg: config node; `MODEL.MASK_ON` and `MODEL.ROI_MASK_HEAD.*` are read here.
            input_shape (dict): maps feature name to an object exposing `.stride` and
                `.channels`.
        """
        # fmt: off
        self.mask_on = cfg.MODEL.MASK_ON
        if not self.mask_on:
            return
        self.mask_coarse_in_features = cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES
        self.mask_coarse_side_size = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION
        self._feature_scales = {k: 1.0 / v.stride for k, v in input_shape.items()}
        # fmt: on

        # Builtin `sum` keeps the channel count a plain Python int rather than a NumPy scalar.
        in_channels = sum(input_shape[f].channels for f in self.mask_coarse_in_features)
        self.mask_coarse_head = build_mask_head(
            cfg,
            ShapeSpec(
                channels=in_channels,
                width=self.mask_coarse_side_size,
                height=self.mask_coarse_side_size,
            ),
        )
        self._init_point_head(cfg, input_shape)

    def _init_point_head(self, cfg, input_shape):
        """
        Read the point-head settings from `cfg` and build the mask point head.

        Does nothing beyond setting `self.mask_point_on` when
        `cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON` is false.

        Args:
            cfg: config node; `MODEL.ROI_MASK_HEAD.POINT_HEAD_ON`, `MODEL.ROI_HEADS.NUM_CLASSES`
                and `MODEL.POINT_HEAD.*` are read here.
            input_shape (dict): maps feature name to an object exposing `.channels`.
        """
        # fmt: off
        self.mask_point_on = cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON
        if not self.mask_point_on:
            return
        # The RoI heads and the point head must agree on the number of classes.
        assert cfg.MODEL.ROI_HEADS.NUM_CLASSES == cfg.MODEL.POINT_HEAD.NUM_CLASSES
        self.mask_point_in_features = cfg.MODEL.POINT_HEAD.IN_FEATURES
        self.mask_point_train_num_points = cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS
        self.mask_point_oversample_ratio = cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO
        self.mask_point_importance_sample_ratio = cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO
        # The next two parameters are used in the adaptive subdivision inference procedure.
        self.mask_point_subdivision_steps = cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS
        self.mask_point_subdivision_num_points = cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS
        # fmt: on

        # Builtin `sum` keeps the channel count a plain Python int rather than a NumPy scalar.
        in_channels = sum(input_shape[f].channels for f in self.mask_point_in_features)
        self.mask_point_head = build_point_head(
            cfg, ShapeSpec(channels=in_channels, width=1, height=1)
        )

    def _forward_mask(self, features, instances):
        """
        Forward logic of the mask prediction branch.

        Args:
            features (dict[str, Tensor]): #level input features for mask prediction
            instances (list[Instances]): the per-image instances to train/predict masks.
                In training, they can be the proposals.
                In inference, they can be the predicted boxes.

        Returns:
            In training, a dict of losses.
            In inference, update `instances` with new fields "pred_masks" and return it.
        """
        if not self.mask_on:
            return {} if self.training else instances

        if self.training:
            proposals, _ = select_foreground_proposals(instances, self.num_classes)
            proposal_boxes = [x.proposal_boxes for x in proposals]
            mask_coarse_logits = self._forward_mask_coarse(features, proposal_boxes)

            losses = {"loss_mask": mask_rcnn_loss(mask_coarse_logits, proposals)}
            losses.update(self._forward_mask_point(features, mask_coarse_logits, proposals))
            return losses
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            mask_coarse_logits = self._forward_mask_coarse(features, pred_boxes)

            mask_logits = self._forward_mask_point(features, mask_coarse_logits, instances)
            mask_rcnn_inference(mask_logits, instances)
            return instances

    def _forward_mask_coarse(self, features, boxes):
        """
        Forward logic of the coarse mask head.

        Args:
            features (dict[str, Tensor]): feature maps keyed by feature name.
            boxes (list): per-image boxes; must be non-empty because the device is taken
                from `boxes[0]`.

        Returns:
            The output of `self.mask_coarse_head` applied to the features sampled on a regular
            grid inside each box.
        """
        # Builtin `sum` is used here: passing a generator to `np.sum` is deprecated in NumPy.
        num_boxes = sum(len(x) for x in boxes)
        point_coords = generate_regular_grid_point_coords(
            num_boxes, self.mask_coarse_side_size, boxes[0].device
        )
        mask_coarse_features_list = [features[k] for k in self.mask_coarse_in_features]
        features_scales = [self._feature_scales[k] for k in self.mask_coarse_in_features]
        # For regular grids of points, this function is equivalent to `len(features_list)` calls
        # of `ROIAlign` (with `SAMPLING_RATIO=2`), and concat the results.
        mask_features, _ = point_sample_fine_grained_features(
            mask_coarse_features_list, features_scales, boxes, point_coords
        )
        return self.mask_coarse_head(mask_features)

    def _forward_mask_point(self, features, mask_coarse_logits, instances):
        """
        Forward logic of the mask point head.

        Args:
            features (dict[str, Tensor]): feature maps keyed by feature name.
            mask_coarse_logits (Tensor): output of the coarse mask head.
            instances (list[Instances]): per-image instances; `proposal_boxes` and `gt_classes`
                are read in training, `pred_boxes` and `pred_classes` in inference.

        Returns:
            In training, a dict with the "loss_mask_point" loss (empty if the point head is off).
            In inference, the refined mask logits (`mask_coarse_logits` itself if the point
            head is off or there are no predictions).
        """
        if not self.mask_point_on:
            return {} if self.training else mask_coarse_logits

        mask_features_list = [features[k] for k in self.mask_point_in_features]
        features_scales = [self._feature_scales[k] for k in self.mask_point_in_features]

        if self.training:
            proposal_boxes = [x.proposal_boxes for x in instances]
            gt_classes = cat([x.gt_classes for x in instances])
            # Choosing where to sample is not part of the differentiable graph.
            with torch.no_grad():
                point_coords = get_uncertain_point_coords_with_randomness(
                    mask_coarse_logits,
                    lambda logits: calculate_uncertainty(logits, gt_classes),
                    self.mask_point_train_num_points,
                    self.mask_point_oversample_ratio,
                    self.mask_point_importance_sample_ratio,
                )

            fine_grained_features, point_coords_wrt_image = point_sample_fine_grained_features(
                mask_features_list, features_scales, proposal_boxes, point_coords
            )
            coarse_features = point_sample(mask_coarse_logits, point_coords, align_corners=False)
            point_logits = self.mask_point_head(fine_grained_features, coarse_features)
            return {
                "loss_mask_point": roi_mask_point_loss(
                    point_logits, instances, point_coords_wrt_image
                )
            }
        else:
            pred_boxes = [x.pred_boxes for x in instances]
            pred_classes = cat([x.pred_classes for x in instances])
            # The subdivision code will fail with the empty list of boxes
            if len(pred_classes) == 0:
                return mask_coarse_logits

            mask_logits = mask_coarse_logits.clone()
            for subdivision_step in range(self.mask_point_subdivision_steps):
                mask_logits = interpolate(
                    mask_logits, scale_factor=2, mode="bilinear", align_corners=False
                )
                # If `mask_point_subdivision_num_points` is larger or equal to the
                # resolution of the next step, then we can skip this step
                H, W = mask_logits.shape[-2:]
                if (
                    self.mask_point_subdivision_num_points >= 4 * H * W
                    and subdivision_step < self.mask_point_subdivision_steps - 1
                ):
                    continue
                uncertainty_map = calculate_uncertainty(mask_logits, pred_classes)
                point_indices, point_coords = get_uncertain_point_coords_on_grid(
                    uncertainty_map, self.mask_point_subdivision_num_points
                )
                fine_grained_features, _ = point_sample_fine_grained_features(
                    mask_features_list, features_scales, pred_boxes, point_coords
                )
                coarse_features = point_sample(
                    mask_coarse_logits, point_coords, align_corners=False
                )
                point_logits = self.mask_point_head(fine_grained_features, coarse_features)

                # put mask point predictions to the right places on the upsampled grid.
                R, C, H, W = mask_logits.shape
                point_indices = point_indices.unsqueeze(1).expand(-1, C, -1)
                mask_logits = (
                    mask_logits.reshape(R, C, H * W)
                    .scatter_(2, point_indices, point_logits)
                    .view(R, C, H, W)
                )
            return mask_logits