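"""FETE_model: an encoder-decoder face generator with separate convolutional
encoders for audio, pose, emotion and blink conditioning, U-Net-style skip
connections from the face encoder into the decoder, and self-attention blocks
in the decoder. Residual blocks live in a separate Conv2d_res wrapper so the
exported graph stays TensorRT-friendly (no 'if' inside forward)."""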
import torch
from torch import nn
from torch.nn import functional as F


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)


class Conv2d_res(nn.Module):
    # TensorRT does not support the 'if' statement, so we define a separate
    # Conv2d_res class for the residual blocks instead of branching inside Conv2d.
    def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        out += x
        return self.act(out)


class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)


class FETE_model(nn.Module):
    def __init__(self):
        super(FETE_model, self).__init__()

        self.face_encoder_blocks = nn.ModuleList(
            [
                nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=2, padding=3)),  # 256,256 -> 128,128
                nn.Sequential(
                    Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 64,64
                    Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 32,32
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 16,16
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 8,8
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 4,4
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(512, 512, kernel_size=3, stride=2, padding=0),  # 1,1
                    Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
                ),
            ]
        )

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
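        # The comment in forward() indicates audio_sequences arrive as (B, T, 1, 80, 16)
        # during training and are folded along the batch dimension; the strided convs
        # above reduce each 80x16 audio feature window to a 512x1x1 embedding.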

        self.pose_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        self.emotion_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=7, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        self.blink_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=1, stride=(1, 2), padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
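        # Each of the four conditioning encoders above is expected to reduce its
        # input to a (B, 512, 1, 1) embedding; forward() concatenates them
        # channel-wise into a (B, 2048, 1, 1) tensor that seeds the decoder.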

        self.face_decoder_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
                ),
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),  # 4,4
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),  # 8,8
                    Self_Attention(512, 512),
                ),
                nn.Sequential(
                    Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),  # 16,16
                    Self_Attention(384, 384),
                ),
                nn.Sequential(
                    Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),  # 32,32
                    Self_Attention(256, 256),
                ),
                nn.Sequential(
                    Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),  # 64,64
                ),
                nn.Sequential(
                    Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),  # 128,128
                ),
            ]
        )
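        # Skip-connection channel arithmetic: each decoder block's input width is the
        # previous decoder output concatenated with the matching encoder feature map,
        # i.e. 2048 = 4 conditioning embeddings x 512, then 512+512=1024, 512+512=1024,
        # 512+256=768, 384+128=512, 256+64=320, 128+32=160, and finally 64+16=80 going
        # into output_block.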

        # self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
        #                                   nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
        #                                   nn.Sigmoid())
        self.output_block = nn.Sequential(
            Conv2dTranspose(80, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(
        self,
        face_sequences,
        audio_sequences,
        pose_sequences,
        emotion_sequences,
        blink_sequences,
    ):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        # disabled for inference
        # input_dim_size = len(face_sequences.size())
        # if input_dim_size > 4:
        #     audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
        #     pose_sequences = torch.cat([pose_sequences[:, i] for i in range(pose_sequences.size(1))], dim=0)
        #     emotion_sequences = torch.cat([emotion_sequences[:, i] for i in range(emotion_sequences.size(1))], dim=0)
        #     blink_sequences = torch.cat([blink_sequences[:, i] for i in range(blink_sequences.size(1))], dim=0)
        #     face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
        # print(audio_sequences.size(), face_sequences.size(), pose_sequences.size(), emotion_sequences.size())

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
        pose_embedding = self.pose_encoder(pose_sequences)  # B, 512, 1, 1
        emotion_embedding = self.emotion_encoder(emotion_sequences)  # B, 512, 1, 1
        blink_embedding = self.blink_encoder(blink_sequences)  # B, 512, 1, 1
        inputs_embedding = torch.cat((audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1)  # B, 2048, 1, 1
        # print(audio_embedding.size(), pose_embedding.size(), emotion_embedding.size(), inputs_embedding.size())

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            # print(x.shape)
            feats.append(x)

        x = inputs_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            # print(x.shape)
            # concatenate with the matching encoder feature map (U-Net-style skip)
            x = torch.cat((x, feats[-1]), dim=1)
            feats.pop()

        x = self.output_block(x)

        # if input_dim_size > 4:
        #     x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
        #     outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)
        # else:
        outputs = x

        return outputs


class Self_Attention(nn.Module):
    """
    Source-Reference Attention Layer
    """

    def __init__(self, in_planes_s, in_planes_r):
        """
        Parameters
        ----------
        in_planes_s: int
            Number of input source feature vector channels.
        in_planes_r: int
            Number of input reference feature vector channels.
        """
        super(Self_Attention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_planes_s, out_channels=in_planes_s // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, source):
        """
        Parameters
        ----------
        source : torch.Tensor
            Source feature maps (B x Cs x Hs x Ws). The reference is the source
            itself here, so the layer acts as plain self-attention.

        Returns
        -------
        torch.Tensor
            Source-reference attention value added to the input source features,
            returned in the input's original dtype.
        """
        orig_dtype = source.dtype
        # compute attention in float32 for numerical stability, then cast back at the end
        source = source.float() if orig_dtype == torch.float16 else source
        reference = source

        s_batchsize, sC, sH, sW = source.size()
        r_batchsize, rC, rH, rW = reference.size()

        proj_query = self.query_conv(source).view(s_batchsize, -1, sH * sW).permute(0, 2, 1)
        proj_key = self.key_conv(reference).view(r_batchsize, -1, rW * rH)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)

        proj_value = self.value_conv(reference).view(r_batchsize, -1, rH * rW)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(s_batchsize, sC, sH, sW)
        out = self.gamma * out + source

        return out.to(orig_dtype)
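

# Minimal shape sanity check (not part of the original training/inference pipeline).
# The face (6, 256, 256) and audio (1, 80, 16) shapes follow the comments above;
# the pose/emotion/blink shapes below are hypothetical placeholders chosen only so
# that each conditioning encoder reduces to a (B, 512, 1, 1) embedding. The real
# feature dimensions may differ.
if __name__ == "__main__":
    model = FETE_model().eval()
    B = 2
    face = torch.randn(B, 6, 256, 256)
    audio = torch.randn(B, 1, 80, 16)
    pose = torch.randn(B, 1, 4, 16)      # hypothetical shape
    emotion = torch.randn(B, 1, 8, 16)   # hypothetical shape
    blink = torch.randn(B, 1, 1, 16)     # hypothetical shape
    with torch.no_grad():
        out = model(face, audio, pose, emotion, blink)
    print(out.shape)  # expected: torch.Size([2, 3, 256, 256])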