# --- Hugging Face file-viewer header (non-code scraper residue, kept as a comment) ---
# author: gavinyuan | commit: d25d54e | message: "update: app.py" | size: 3.96 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from third_party.arcface.iresnet import iresnet50, iresnet100
class MouthNet(nn.Module):
    """Embed the mouth region of an aligned face crop into a feature vector.

    Holds a BiSeNet face parser (used by :meth:`get_any_mask`) and an
    IResNet-50 backbone that maps the (optionally masked) crop to a
    ``feature_dim``-dimensional embedding.
    """

    def __init__(self,
                 bisenet: nn.Module,
                 feature_dim: int = 64,
                 crop_param: tuple = (0, 0, 112, 112),
                 iresnet_pretrained: bool = False,
                 ):
        super().__init__()
        # crop_param is a PIL-style box (left, upper, right, lower).
        crop_h = crop_param[3] - crop_param[1]
        crop_w = crop_param[2] - crop_param[0]
        # The backbone's final FC assumes a 7x7 spatial grid for a 112x112
        # input; rescale that grid area to the actual crop size.
        fc_scale = int(math.ceil(crop_h / 112 * 7) * math.ceil(crop_w / 112 * 7))
        self.bisenet = bisenet
        self.backbone = iresnet50(
            pretrained=iresnet_pretrained,
            num_features=feature_dim,
            fp16=False,
            fc_scale=fc_scale,
        )
        # ImageNet ("VGG") channel statistics, shaped (3, 1, 1) to broadcast
        # over (B, 3, H, W) images. Registered as buffers so they follow the
        # module across devices without being trained.
        self.register_buffer(
            name="vgg_mean",
            tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]], requires_grad=False),
        )
        self.register_buffer(
            name="vgg_std",
            tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]], requires_grad=False),
        )

    def forward(self, x):
        """Return the backbone embedding of ``x``.

        NOTE(review): the parsing-based mouth mask is commented out below, so
        the mask is the constant 1 and the full crop reaches the backbone.
        """
        # with torch.no_grad():
        #     mouth_mask = self.get_any_mask(x, par=[11, 12, 13], normalized=True)  # (B,1,H,W) in [0,1]
        mouth_mask = 1
        masked = x * mouth_mask  # no-op while mouth_mask == 1
        return self.backbone(masked)

    def get_any_mask(self, img, par, normalized=False):
        """Return a soft (B,1,H,W) mask selecting the parsing classes in ``par``.

        BiSeNet label table:
        [0 'background', 1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
         6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth',
         12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair',
         18 'hat']
        """
        out_size = img.size()[-1]
        with torch.no_grad():
            # BiSeNet operates at 512x512.
            resized = F.interpolate(img, size=512, mode="nearest")
            if not normalized:
                # Map from [-1, 1] to [0, 1], then apply ImageNet statistics.
                resized = resized * 0.5 + 0.5
                resized = resized.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
            logits = self.bisenet(resized)[0]
            labels = logits.softmax(1).argmax(1)  # (B, 512, 512) class ids
        # Union of the requested classes as a {0, 1} float map.
        selected = torch.zeros_like(labels)
        for cls_id in par:
            selected = selected + (labels == cls_id).float()
        selected = selected.unsqueeze(1)
        # Bilinear resize back to the input resolution yields soft edges.
        return F.interpolate(selected, size=out_size, mode="bilinear", align_corners=True)

    def save_backbone(self, path: str):
        """Serialize only the backbone weights to ``path``."""
        torch.save(self.backbone.state_dict(), path)

    def load_backbone(self, path: str):
        """Load backbone weights saved by :meth:`save_backbone`."""
        state = torch.load(path, map_location='cpu')
        self.backbone.load_state_dict(state)
if __name__ == "__main__":
    # Smoke test / profiling script: load real BiSeNet weights, push one random
    # mouth crop through MouthNet, then profile an IResNet backbone with thop.
    from third_party.bisenet.bisenet import BiSeNet
    bisenet = BiSeNet(19)  # 19 face-parsing classes (see get_any_mask's label table)
    # Hard-coded checkpoint path — only valid on the author's machine.
    bisenet.load_state_dict(
        torch.load(
            "/gavin/datasets/hanbang/79999_iter.pth",
            map_location="cpu",
        )
    )
    bisenet.eval()
    bisenet.requires_grad_(False)
    # PIL-style crop box (left, upper, right, lower): lower-center of a 112x112 face.
    crop_param = (28, 56, 84, 112)
    import numpy as np
    # NOTE(review): 225 looks like a typo for 255 (uint8 max); harmless here
    # since the image is random noise either way — confirm intent.
    img = np.random.randn(112, 112, 3) * 225
    from PIL import Image
    img = Image.fromarray(img.astype(np.uint8))
    img = img.crop(crop_param)  # -> 56x56 mouth-region crop
    from torchvision import transforms
    trans = transforms.ToTensor()
    img = trans(img).unsqueeze(0)
    img = img.repeat(3, 1, 1, 1)  # fake batch of 3 identical crops
    print(img.shape)
    net = MouthNet(
        bisenet=bisenet,
        feature_dim=64,
        crop_param=crop_param
    )
    mouth_feat = net(img)
    print(mouth_feat.shape)
    import thop
    # Same fc_scale formula as MouthNet.__init__ (computed but unused below).
    crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H,W)
    fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7))
    # NOTE(review): profiles iresnet100 with the default fc_scale, although
    # MouthNet above uses iresnet50 — confirm which backbone the reported
    # numbers are meant to describe.
    backbone = iresnet100(
        pretrained=False,
        num_features=64,
        fp16=False,
        # fc_scale=fc_scale,
    )
    flops, params = thop.profile(backbone, inputs=(torch.randn(1, 3, 112, 112),), verbose=False)
    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))