| | |
| |
|
| | """ |
Appearance feature extractor (F), as defined in the paper: maps the source image s to a 3D appearance feature volume.
| | """ |
| |
|
| | import torch |
| | from torch import nn |
| | from .util import SameBlock2d, DownBlock2d, ResBlock3d |
| |
|
| |
|
class AppearanceFeatureExtractor(nn.Module):
    """Appearance feature extractor F: maps a source image to a 3D appearance feature volume.

    Pipeline, as assembled in ``__init__``:

    1. ``first``: a 3x3 ``SameBlock2d`` from ``image_channel`` to ``block_expansion`` channels.
    2. ``down_blocks``: ``num_down_blocks`` 3x3 ``DownBlock2d`` layers. Block ``i`` outputs
       ``min(max_features, block_expansion * 2 ** (i + 1))`` channels, and each block's input
       width equals the previous layer's output width.
    3. ``second``: a 1x1 ``nn.Conv2d`` projecting to ``max_features`` channels.
    4. The resulting 4D map is reshaped with ``view`` into a 5D volume of shape
       ``(bs, reshape_channel, reshape_depth, h, w)``.
    5. ``resblocks_3d``: ``num_resblocks`` ``ResBlock3d`` layers built with ``reshape_channel``
       channels (an empty ``nn.Sequential`` when ``num_resblocks == 0``).

    Note: the ``view`` in step 4 only succeeds when
    ``max_features == reshape_channel * reshape_depth``, because ``second`` always
    emits ``max_features`` channels.

    Args:
        image_channel: number of channels in the input image.
        block_expansion: output width of ``first`` and base width of the down blocks.
        num_down_blocks: number of ``DownBlock2d`` layers (0 is allowed).
        max_features: channel cap for the down blocks and output width of ``second``.
        reshape_channel: channel dimension of the 3D feature volume.
        reshape_depth: depth dimension of the 3D feature volume.
        num_resblocks: number of ``ResBlock3d`` layers applied to the volume.
    """

    def __init__(
        self,
        image_channel: int,
        block_expansion: int,
        num_down_blocks: int,
        max_features: int,
        reshape_channel: int,
        reshape_depth: int,
        num_resblocks: int,
    ) -> None:
        super().__init__()
        self.image_channel = image_channel
        self.block_expansion = block_expansion
        self.num_down_blocks = num_down_blocks
        self.max_features = max_features
        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth
        self.num_resblocks = num_resblocks

        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))

        # Start from the width actually emitted by `first` and chain it through the
        # loop. This has two effects:
        #   (a) each block's input width matches the previous layer's output, even when
        #       block_expansion > max_features;
        #   (b) `out_features` is bound for `second` when num_down_blocks == 0.
        out_features = block_expansion
        down_blocks = []
        for i in range(num_down_blocks):
            in_features = out_features
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        # 1x1 projection to max_features channels, which forward() then splits into
        # (reshape_channel, reshape_depth) via `view`.
        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            # Submodule names '3dr0', '3dr1', ... are part of the state_dict keys.
            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))

    def forward(self, source_image: torch.Tensor) -> torch.Tensor:
        """Encode ``source_image`` into a 3D appearance feature volume.

        Args:
            source_image: input batch, presumably ``(bs, image_channel, H, W)``
                (TODO confirm). The output of ``second`` is unpacked as 4D below.

        Returns:
            The output of ``resblocks_3d`` applied to a volume of shape
            ``(bs, reshape_channel, reshape_depth, h, w)``. Here ``h, w`` are the spatial
            dims after the down blocks. The final shape is presumably unchanged by the
            residual blocks; see ``.util.ResBlock3d`` to confirm.
        """
        out = self.first(source_image)
        for down_block in self.down_blocks:
            out = down_block(out)
        out = self.second(out)

        bs, _, h, w = out.shape
        # Split the max_features channel axis into (reshape_channel, reshape_depth).
        # This requires max_features == reshape_channel * reshape_depth.
        f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
        return self.resblocks_3d(f_s)
| |
|