File size: 3,963 Bytes
a104d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d25d54e
a104d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from third_party.arcface.iresnet import iresnet50, iresnet100

class MouthNet(nn.Module):
    """Embed a (mouth) crop of a face image with an IResNet-50 backbone.

    A BiSeNet face parser is kept so that ``get_any_mask`` can build a mask of
    selected face-parsing classes, but ``forward`` currently does NOT apply it
    (the mask is the constant 1), so the backbone sees the whole input crop.

    Args:
        bisenet: Face-parsing network used by ``get_any_mask``. Its call result
            is indexed with ``[0]`` and that tensor is treated as per-class
            scores along dim 1.
        feature_dim: Size of the embedding produced by the backbone.
        crop_param: Crop box ``(left, upper, right, lower)`` in pixels (the
            ordering ``PIL.Image.crop`` uses). Only the box extent is used, to
            size the backbone's final FC layer.
        iresnet_pretrained: Forwarded to ``iresnet50(pretrained=...)``.
    """

    def __init__(self,
                 bisenet: nn.Module,
                 feature_dim: int = 64,
                 crop_param: tuple = (0, 0, 112, 112),
                 iresnet_pretrained: bool = False,
                 ):
        super(MouthNet, self).__init__()

        crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H,W)
        # Scale each side from the 112-px reference (presumably a 7x7 final
        # feature map in iresnet — confirm against third_party.arcface) and round
        # up. Integer ceil-division ``-(-a // b)`` is exact, whereas
        # ``math.ceil(h / 112 * 7)`` can be pushed up by one if float round-off
        # lands just above an integer.
        fc_h = -(-crop_size[0] * 7 // 112)
        fc_w = -(-crop_size[1] * 7 // 112)
        fc_scale = int(fc_h * fc_w)

        self.bisenet = bisenet
        self.backbone = iresnet50(
            pretrained=iresnet_pretrained,
            num_features=feature_dim,
            fp16=False,
            fc_scale=fc_scale,
        )

        # Per-channel mean/std (shape (3,1,1)) used by get_any_mask when its
        # input is not already normalized.
        self.register_buffer(
            name="vgg_mean",
            tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]], requires_grad=False),
        )
        self.register_buffer(
            name="vgg_std",
            tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]], requires_grad=False),
        )

    def forward(self, x):
        """Return backbone features for ``x``.

        Mouth masking is currently disabled: the mask is the constant 1, so
        ``x`` reaches the backbone unchanged. To re-enable it, replace the
        constant with (mouth, upper-lip, lower-lip parsing classes):

            with torch.no_grad():
                x_mouth_mask = self.get_any_mask(x, par=[11, 12, 13], normalized=True)
        """
        x_mouth_mask = 1
        # assumes x is (B,3,H,W) with H,W matching crop_param — TODO confirm
        x_mouth = x * x_mouth_mask
        mouth_feature = self.backbone(x_mouth)
        return mouth_feature

    def get_any_mask(self, img, par, normalized=False):
        """Build a mask selecting the given face-parsing classes.

        Args:
            img: Image batch ``(B, C, H, W)``. It is resized to 512x512 before
                parsing. When ``normalized`` is False it is first mapped with
                ``x * 0.5 + 0.5`` (presumably from [-1, 1] to [0, 1]) and then
                standardized with the ``vgg_mean``/``vgg_std`` buffers.
            par: Iterable of parsing class ids to select (see the table below).
                Duplicates are harmless. An empty iterable yields an all-zero
                mask.
            normalized: Whether ``img`` is already in the parser's input space.

        Returns:
            Float tensor ``(B, 1, H, W)`` at the spatial size of ``img``, with
            values in [0, 1] (1 = pixel parsed as one of ``par``). It is
            bilinearly resized from 512x512, so edges are soft.
        """
        # Parsing class ids:
        # [0 'background', 1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
        #  6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip',
        #  13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
        # Keep both H and W so non-square crops get a non-square mask back.
        ori_size = tuple(img.shape[-2:])
        with torch.no_grad():
            img = F.interpolate(img, size=512, mode="nearest")
            if not normalized:
                img = img * 0.5 + 0.5
                img = img.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
            out = self.bisenet(img)[0]
            parsing = out.softmax(1).argmax(1)  # (B,512,512) class ids
        # OR the per-class hits so the mask stays binary even with repeated ids,
        # and cast to float once so interpolation never sees an integer tensor.
        mask = torch.zeros_like(parsing, dtype=torch.bool)
        for p in par:
            mask |= parsing == p
        mask = mask.float().unsqueeze(1)
        mask = F.interpolate(mask, size=ori_size, mode="bilinear", align_corners=True)
        return mask

    def save_backbone(self, path: str):
        """Save only the backbone's ``state_dict`` to ``path``."""
        torch.save(self.backbone.state_dict(), path)

    def load_backbone(self, path: str):
        """Load a backbone ``state_dict`` from ``path`` (mapped to CPU first).

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        self.backbone.load_state_dict(torch.load(path, map_location='cpu'))


if __name__ == "__main__":
    # Smoke test: build MouthNet around a pretrained BiSeNet, push a random
    # mouth crop through it, then profile an IResNet-100 backbone with thop.
    import numpy as np
    import thop
    from PIL import Image
    from torchvision import transforms

    from third_party.bisenet.bisenet import BiSeNet

    bisenet = BiSeNet(19)
    bisenet.load_state_dict(
        torch.load(
            "/gavin/datasets/hanbang/79999_iter.pth",
            map_location="cpu",
        )
    )
    bisenet.eval()
    bisenet.requires_grad_(False)

    # (left, upper, right, lower) box for PIL.Image.crop -> a 56x56 crop.
    crop_param = (28, 56, 84, 112)

    # Random RGB image drawn directly as uint8, so every pixel is a valid
    # 0-255 value (casting out-of-range floats to uint8 wraps/is undefined).
    img = np.random.randint(0, 256, size=(112, 112, 3), dtype=np.uint8)
    img = Image.fromarray(img)
    img = img.crop(crop_param)

    trans = transforms.ToTensor()
    img = trans(img).unsqueeze(0)
    img = img.repeat(3, 1, 1, 1)  # batch of 3 identical crops
    print(img.shape)

    net = MouthNet(
        bisenet=bisenet,
        feature_dim=64,
        crop_param=crop_param
    )
    with torch.no_grad():
        mouth_feat = net(img)
    print(mouth_feat.shape)

    # Profiles a full 112x112 IResNet-100 with its default fc_scale — not the
    # cropped IResNet-50 used inside MouthNet above.
    backbone = iresnet100(
        pretrained=False,
        num_features=64,
        fp16=False,
    )
    flops, params = thop.profile(backbone, inputs=(torch.randn(1, 3, 112, 112),), verbose=False)
    print(f"#Params={params / 1e6:.2f}M, GFLOPS={flops / 1e9:.2f}")