import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from third_party.arcface.iresnet import iresnet50, iresnet100


class MouthNet(nn.Module):
    """Feature extractor built on an iresnet50 backbone.

    The module keeps a face-parsing network (``bisenet``) around so that
    :meth:`get_any_mask` can build region masks, but :meth:`forward` currently
    feeds the input to the backbone unmasked.

    Args:
        bisenet: Parsing network used by :meth:`get_any_mask`; only its first
            output is read.
        feature_dim: Passed to the backbone as ``num_features``.
        crop_param: Crop box as ``(x0, y0, x1, y1)``; the crop height is
            ``y1 - y0`` and the crop width is ``x1 - x0``.
        iresnet_pretrained: Passed to the backbone as ``pretrained``.
    """

    def __init__(self,
                 bisenet: nn.Module,
                 feature_dim: int = 64,
                 crop_param: tuple = (0, 0, 112, 112),
                 iresnet_pretrained: bool = False,
                 ):
        super().__init__()

        crop_height = crop_param[3] - crop_param[1]
        crop_width = crop_param[2] - crop_param[0]
        # Scale factor handed to the backbone: 7 * 7 for a 112 x 112 crop.
        # NOTE(review): presumably the flattened spatial size of the backbone's
        # last feature map -- confirm against the iresnet implementation.
        fc_scale = int(
            math.ceil(crop_height / 112 * 7) * math.ceil(crop_width / 112 * 7)
        )

        self.bisenet = bisenet
        self.backbone = iresnet50(
            pretrained=iresnet_pretrained,
            num_features=feature_dim,
            fp16=False,
            fc_scale=fc_scale,
        )

        # Per-channel statistics, shaped (3, 1, 1) so they broadcast over
        # (B, 3, H, W) batches. Registered as buffers so they follow the
        # module across devices without being trained.
        self.register_buffer(
            name="vgg_mean",
            tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]], requires_grad=False),
        )
        self.register_buffer(
            name="vgg_std",
            tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]], requires_grad=False),
        )

    def forward(self, x):
        """Run the backbone on ``x`` and return its feature output.

        Masking is currently disabled: the mask is the constant ``1``, so the
        backbone sees the input unchanged.
        """
        # Disabled mouth masking, kept for reference:
        # with torch.no_grad():
        #     x_mouth_mask = self.get_any_mask(x, par=[11, 12, 13], normalized=True)  # (B,1,H,W), in [0,1], 1:chosed
        x_mouth_mask = 1
        x_mouth = x * x_mouth_mask  # (B,3,112,112)
        mouth_feature = self.backbone(x_mouth)
        return mouth_feature

    def get_any_mask(self, img, par, normalized=False):
        """Build a mask that is non-zero where the parser predicts a label in ``par``.

        Parsing label indices:
        [0 'background', 1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
         6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip',
         13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']

        Args:
            img: Image batch; it is resized to 512 before parsing, and the mask
                is resized back to the input's last two dimensions.
            par: Iterable of label indices to include. An empty iterable
                yields an all-zero mask.
            normalized: When false, ``img`` is mapped with ``* 0.5 + 0.5`` and
                then standardised with ``vgg_mean`` / ``vgg_std`` before parsing.
                NOTE(review): this assumes ``normalized=True`` means the input
                already had that standardisation applied, so both steps are
                skipped together -- confirm against callers.

        Returns:
            Float mask with a singleton channel dimension inserted at dim 1,
            at the spatial size of ``img``.
        """
        # Restore to (H, W), not just the width, so non-square inputs get a
        # mask of matching shape.
        ori_size = tuple(img.shape[-2:])

        with torch.no_grad():
            img = F.interpolate(img, size=512, mode="nearest")
            if not normalized:
                img = img * 0.5 + 0.5
                img = img.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
            out = self.bisenet(img)[0]
            parsing = out.softmax(1).argmax(1)

        # Accumulate in float from the start: `parsing` is an integer tensor,
        # and bilinear interpolation below needs a floating-point input even
        # when `par` is empty.
        mask = torch.zeros_like(parsing, dtype=torch.float32)
        for label in par:
            mask = mask + (parsing == label).float()

        mask = mask.unsqueeze(1)
        mask = F.interpolate(mask, size=ori_size, mode="bilinear", align_corners=True)
        return mask

    def save_backbone(self, path: str):
        """Save only the backbone's ``state_dict`` to ``path``."""
        torch.save(self.backbone.state_dict(), path)

    def load_backbone(self, path: str):
        """Load a backbone ``state_dict`` from ``path`` (mapped to CPU first)."""
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = torch.load(path, map_location='cpu')
        self.backbone.load_state_dict(state_dict)


if __name__ == "__main__":
    # Smoke test: run MouthNet on a random crop, then profile an iresnet100.
    from third_party.bisenet.bisenet import BiSeNet

    bisenet = BiSeNet(19)
    bisenet.load_state_dict(
        torch.load(
            "/gavin/datasets/hanbang/79999_iter.pth",
            map_location="cpu",
        )
    )
    bisenet.eval()
    bisenet.requires_grad_(False)

    crop_param = (28, 56, 84, 112)

    import numpy as np

    img = np.random.randn(112, 112, 3) * 225

    from PIL import Image

    img = Image.fromarray(img.astype(np.uint8))
    img = img.crop(crop_param)

    from torchvision import transforms

    trans = transforms.ToTensor()
    img = trans(img).unsqueeze(0)
    img = img.repeat(3, 1, 1, 1)  # batch of three identical crops
    print(img.shape)

    net = MouthNet(
        bisenet=bisenet,
        feature_dim=64,
        crop_param=crop_param
    )
    mouth_feat = net(img)
    print(mouth_feat.shape)

    import thop

    crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H,W)
    # Only needed if the commented-out `fc_scale` argument below is re-enabled.
    fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7))
    backbone = iresnet100(
        pretrained=False,
        num_features=64,
        fp16=False,
        # fc_scale=fc_scale,
    )
    flops, params = thop.profile(backbone, inputs=(torch.randn(1, 3, 112, 112),), verbose=False)
    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))