# IDM-VTON
# update IDM-VTON Demo
# 938e515
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import List
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.utils.registry import Registry
# Public names of this module.
__all__ = [
    "FastRCNNConvFCHead",
    "build_box_head",
    "ROI_BOX_HEAD_REGISTRY",
]

# Registry that box-head classes add themselves to via `.register()`.
ROI_BOX_HEAD_REGISTRY = Registry("ROI_BOX_HEAD")
ROI_BOX_HEAD_REGISTRY.__doc__ = """
Registry for box heads, which make box predictions from per-region features.
The registered object will be called with `obj(cfg, input_shape)`.
"""
# To get torchscript support, we make the head a subclass of `nn.Sequential`.
# Therefore, to add new layers in this head class, please make sure they are
# added in the order they will be used in forward().
@ROI_BOX_HEAD_REGISTRY.register()
class FastRCNNConvFCHead(nn.Sequential):
    """
    Box head built from a stack of 3x3 conv layers (each with an optional norm
    and a ReLU) followed by a stack of fc layers (each with a ReLU).

    Sub-modules are registered in the same order in which `forward` applies
    them.
    """

    @configurable
    def __init__(
        self, input_shape: ShapeSpec, *, conv_dims: List[int], fc_dims: List[int], conv_norm=""
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature.
            conv_dims (list[int]): output channels of each conv layer, in order.
            fc_dims (list[int]): output size of each fc layer, in order.
            conv_norm (str or callable): normalization applied in the conv layers.
                See :func:`detectron2.layers.get_norm` for supported types.
        """
        super().__init__()
        # At least one layer (conv or fc) must be requested.
        assert len(conv_dims) + len(fc_dims) > 0

        # Running record of the feature shape produced so far: (C, H, W) while
        # only convs have been added, a plain int once an fc layer is added.
        self._output_size = (input_shape.channels, input_shape.height, input_shape.width)

        self.conv_norm_relus = []
        for idx, out_channels in enumerate(conv_dims, start=1):
            in_channels, height, width = self._output_size
            conv_layer = Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                # The conv carries a bias only when no normalization is configured.
                bias=not conv_norm,
                norm=get_norm(conv_norm, out_channels),
                activation=nn.ReLU(),
            )
            self.add_module(f"conv{idx}", conv_layer)
            self.conv_norm_relus.append(conv_layer)
            # kernel_size=3 with padding=1: only the channel count is updated.
            self._output_size = (out_channels, height, width)

        self.fcs = []
        if len(fc_dims) > 0:
            # A single flatten layer sits between the conv stack and the first fc.
            self.add_module("flatten", nn.Flatten())
        for idx, out_features in enumerate(fc_dims, start=1):
            in_features = int(np.prod(self._output_size))
            linear = nn.Linear(in_features, out_features)
            self.add_module(f"fc{idx}", linear)
            self.add_module(f"fc_relu{idx}", nn.ReLU())
            self.fcs.append(linear)
            self._output_size = out_features

        # Weight initialization happens after every layer has been constructed:
        # convs first, then fc layers.
        for conv_layer in self.conv_norm_relus:
            weight_init.c2_msra_fill(conv_layer)
        for linear in self.fcs:
            weight_init.c2_xavier_fill(linear)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Translate `cfg.MODEL.ROI_BOX_HEAD` into the keyword arguments of `__init__`.

        Every conv layer uses the same `CONV_DIM` and every fc layer the same
        `FC_DIM`; `NUM_CONV` / `NUM_FC` give how many of each are created.
        """
        head_cfg = cfg.MODEL.ROI_BOX_HEAD
        conv_dims = [head_cfg.CONV_DIM] * head_cfg.NUM_CONV
        fc_dims = [head_cfg.FC_DIM] * head_cfg.NUM_FC
        return {
            "input_shape": input_shape,
            "conv_dims": conv_dims,
            "fc_dims": fc_dims,
            "conv_norm": head_cfg.NORM,
        }

    def forward(self, x):
        """Apply every registered sub-module to `x` in turn and return the result."""
        for module in self:
            x = module(x)
        return x

    @property
    @torch.jit.unused
    def output_shape(self):
        """
        Returns:
            ShapeSpec: the output feature shape
        """
        size = self._output_size
        # After any fc layer the size is a bare int (features only).
        if isinstance(size, int):
            return ShapeSpec(channels=size)
        # Conv-only head: the size is still a (channels, height, width) triple.
        return ShapeSpec(channels=size[0], height=size[1], width=size[2])
def build_box_head(cfg, input_shape):
    """
    Build a box head defined by `cfg.MODEL.ROI_BOX_HEAD.NAME`.

    The name is looked up in `ROI_BOX_HEAD_REGISTRY` and the registered object
    is called with `(cfg, input_shape)`; its return value is returned as-is.
    """
    head_factory = ROI_BOX_HEAD_REGISTRY.get(cfg.MODEL.ROI_BOX_HEAD.NAME)
    return head_factory(cfg, input_shape)