|
|
''' |
|
|
Exploring Temporal Coherence for More General Video Face Forgery Detection @ ICCV'2021 |
|
|
Copyright (c) Xiamen University and its affiliates. |
|
|
Modified by Yinglin Zheng from https://github.com/yinglinzheng/FTCN |
|
|
''' |
|
|
|
|
|
import random

import torch
from torch import nn

from .clip import clip
from .time_transformer import TimeTransformer
|
|
|
|
|
class RandomPatchPool(nn.Module):
    """Keep a single flattened spatial position of a 5-D feature map.

    The input is unpacked as ``(b, c, t, h, w)`` and the two trailing axes are
    flattened, so the result has shape ``(b, c, t)``.

    * When random selection is active (module in training mode *and* the
      ``random_select`` switch is truthy) one flattened index is drawn
      uniformly among the indices that pass the border filter below.
    * Otherwise the fixed flattened index ``h * w // 2`` is used.

    Args:
        random_select: Optional override for the random-selection switch.
            ``None`` (the default) keeps the original behaviour of reading
            ``my_cfg.model.transformer.random_select`` at call time.
            NOTE(review): ``my_cfg`` is not defined or imported in this
            file, so the fallback only works if the project injects it.
    """

    def __init__(self, random_select=None):
        super().__init__()
        self.random_select = random_select

    def _random_selection_enabled(self):
        """Return True when a random index should be drawn for this call."""
        if not self.training:
            # Short-circuit first, exactly like the original ``and`` chain,
            # so evaluation never touches the global configuration.
            return False
        if self.random_select is None:
            return bool(my_cfg.model.transformer.random_select)
        return bool(self.random_select)

    @staticmethod
    def _draw_index(h, w):
        """Rejection-sample a flattened index that passes the border filter.

        The index is decomposed with ``h`` for both the division and the
        modulo, which matches (row, column) of the row-major flattening only
        for square maps (``h == w``).
        """
        while True:
            idx = random.randint(0, h * w - 1)
            i, j = divmod(idx, h)
            if j == 0 or i == h - 1 or j == h - 1:
                continue
            return idx

    def forward(self, x):
        """Select one flattened spatial position.

        Args:
            x: Tensor whose shape unpacks as ``(b, c, t, h, w)``.

        Returns:
            Tensor of shape ``(b, c, t)``.

        Raises:
            ValueError: If random selection is active and ``h < 3``; no index
                can pass the filter then and the sampling loop would never
                terminate.
        """
        b, c, t, h, w = x.shape
        x = x.reshape(b, c, t, h * w)
        if self._random_selection_enabled():
            if h < 3:
                raise ValueError(
                    f"RandomPatchPool needs h >= 3 for random selection, got h={h}"
                )
            idx = self._draw_index(h, w)
        else:
            idx = h * w // 2
        return x[..., idx]
|
|
|
|
|
|
|
|
def valid_idx(idx, h):
    """Tell whether a flattened index passes the border filter.

    ``idx`` is split into ``i = idx // h`` and ``j = idx % h``. The index is
    rejected when ``j`` is ``0`` or ``h - 1``, or when ``i`` is ``h - 1``;
    every other index is accepted.

    Args:
        idx: Flattened position.
        h: Side length used for both the division and the modulo.

    Returns:
        ``True`` if the index is accepted, ``False`` otherwise.
    """
    i, j = divmod(idx, h)
    last = h - 1
    return not (j == 0 or i == last or j == last)
|
|
|
|
|
|
|
|
class RandomAvgPool(nn.Module):
    """Average a 5-D feature map over a subset of flattened spatial positions.

    The input is unpacked as ``(b, c, t, h, w)``, the two trailing axes are
    flattened, and only the positions accepted by ``valid_idx`` are eligible.

    * When random selection is active (module in training mode *and* the
      ``random_select`` switch is truthy), ``k`` eligible positions are drawn
      without replacement and averaged.
    * Otherwise all eligible positions are averaged.

    Args:
        random_select: Optional override for the random-selection switch.
            ``None`` (the default) reads
            ``my_cfg.model.transformer.random_select`` at call time.
        k: Optional override for the number of sampled positions. ``None``
            (the default) reads ``my_cfg.model.transformer.k`` at call time.

    NOTE(review): ``my_cfg`` is not defined or imported in this file, so the
    fallbacks only work if the project injects it.
    """

    def __init__(self, random_select=None, k=None):
        super().__init__()
        self.random_select = random_select
        self.k = k

    def _random_selection_enabled(self):
        """Return True when a random subset should be drawn for this call."""
        if not self.training:
            # Evaluation never consults the global configuration.
            return False
        if self.random_select is None:
            return bool(my_cfg.model.transformer.random_select)
        return bool(self.random_select)

    def forward(self, x):
        """Average over the selected spatial positions.

        Args:
            x: Tensor whose shape unpacks as ``(b, c, t, h, w)``.

        Returns:
            Tensor of shape ``(b, c, t)``.

        Raises:
            ValueError: If no position passes ``valid_idx`` (the mean over an
                empty selection would silently be NaN), or - raised by
                ``random.sample`` - if ``k`` exceeds the number of eligible
                positions.
        """
        b, c, t, h, w = x.shape
        x = x.reshape(b, c, t, h * w)
        candidates = [idx for idx in range(h * w) if valid_idx(idx, h)]
        if not candidates:
            raise ValueError(
                f"RandomAvgPool found no valid spatial position for h={h}, w={w}"
            )
        if self._random_selection_enabled():
            k = my_cfg.model.transformer.k if self.k is None else self.k
            candidates = random.sample(candidates, k)
        # Without random selection every candidate is kept in its natural
        # order: shuffling all of them first would not change which
        # positions are averaged.
        return x[..., candidates].mean(-1)
|
|
|
|
|
class TransformerHead(nn.Module):
    """Pool a 5-D feature map into a token sequence and run ``TimeTransformer``.

    The pooled tensor is reshaped to ``(-1, in_channels, num_patches)`` and
    permuted to ``(-1, num_patches, in_channels)`` before it is handed to the
    transformer, so the input is expected to be laid out as
    ``(batch, in_channels, time, height, width)``.

    Args:
        spatial_size: Side length of the spatial grid of the incoming
            features.
        time_size: Number of positions along the time axis.
        in_channels: Channel count of the incoming features.
        patch_type: How tokens are formed. Defaults to ``"time"``, the value
            that used to be hard-coded.

            * ``"time"``: average over the spatial grid, ``time_size`` tokens.
            * ``"spatial"``: average over time, ``spatial_size ** 2`` tokens.
            * ``"random"``: ``RandomPatchPool``, ``time_size`` tokens.
            * ``"random_avg"``: ``RandomAvgPool``, ``time_size`` tokens.
            * ``"all"``: no pooling,
              ``time_size * spatial_size * spatial_size`` tokens.
        dim: Token width fed to the transformer. ``-1`` (the default, and the
            value that used to be hard-coded) means "same as
            ``in_channels``"; any other value adds a linear projection
            ``self.fc``.

    Raises:
        NotImplementedError: If ``patch_type`` is not one of the values above.
    """

    def __init__(self, spatial_size=7, time_size=8, in_channels=2048,
                 patch_type="time", dim=-1):
        super().__init__()

        if patch_type == "time":
            self.pool = nn.AvgPool3d((1, spatial_size, spatial_size))
            self.num_patches = time_size
        elif patch_type == "spatial":
            self.pool = nn.AvgPool3d((time_size, 1, 1))
            self.num_patches = spatial_size ** 2
        elif patch_type == "random":
            self.pool = RandomPatchPool()
            self.num_patches = time_size
        elif patch_type == "random_avg":
            self.pool = RandomAvgPool()
            self.num_patches = time_size
        elif patch_type == "all":
            self.pool = nn.Identity()
            self.num_patches = time_size * spatial_size * spatial_size
        else:
            raise NotImplementedError(patch_type)

        self.in_channels = in_channels
        # -1 is the sentinel for "keep the incoming channel width".
        self.dim = in_channels if dim == -1 else dim

        # The projection exists only when the widths differ, so the default
        # configuration keeps the same set of parameters as before.
        if self.dim != self.in_channels:
            self.fc = nn.Linear(self.in_channels, self.dim)

        default_params = dict(
            dim=self.dim, depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1,
        )

        self.time_T = TimeTransformer(
            num_patches=self.num_patches, num_classes=1, **default_params
        )

    def forward(self, x):
        """Pool ``x``, turn it into tokens and return the transformer output.

        Args:
            x: Feature tensor; assumed to be
                ``(batch, in_channels, time, height, width)`` - TODO confirm
                against callers for the non-default ``patch_type`` values.

        Returns:
            Whatever ``TimeTransformer`` returns for a
            ``(-1, num_patches, dim)`` input (it is built with
            ``num_classes=1``).
        """
        x = self.pool(x)
        x = x.reshape(-1, self.in_channels, self.num_patches)
        x = x.permute(0, 2, 1)
        if self.dim != self.in_channels:
            # nn.Linear acts on the last axis, so the token axis can stay.
            x = self.fc(x)
        return self.time_T(x)
|
|
|
|
|
|
|
|
class ViT_B_FTCN(nn.Module):
    """CLIP ``ViT-B-16`` image encoder followed by a ``TransformerHead``.

    Every frame of a clip is encoded independently with
    ``clip_model.encode_image``. The per-frame token sequences are arranged
    as a ``(batch, channels, time, 14, 14)`` volume and handed to the head.

    Args:
        channel_size: Accepted for interface compatibility; not used by this
            implementation (the head is built with ``in_channels=512``).
        class_num: Accepted for interface compatibility; not used by this
            implementation.
    """

    # Fixed geometry the head is built for.
    _GRID_SIZE = 14
    _NUM_FRAMES = 8
    _EMBED_DIM = 512

    def __init__(
        self, channel_size=512, class_num=1
    ):
        super(ViT_B_FTCN, self).__init__()
        # clip.load returns a pair; only the model is needed here.
        self.clip_model, _ = clip.load('ViT-B-16')
        self.clip_model = self.clip_model.float()

        self.head = TransformerHead(
            spatial_size=self._GRID_SIZE,
            time_size=self._NUM_FRAMES,
            in_channels=self._EMBED_DIM,
        )

    def forward(self, x):
        """Run the encoder and the head on a batch of clips.

        Args:
            x: Tensor whose shape unpacks as ``(b, t, 3, h, w)``; ``t`` must
                equal the number of frames the head was built for (8).

        Returns:
            The output of ``TransformerHead`` for the batch.

        Raises:
            ValueError: If ``t`` differs from the expected number of frames.
                The head reshapes with ``-1``, so a mismatch would otherwise
                be folded into the batch axis without any error.
        """
        b, t, _, h, w = x.shape
        if t != self._NUM_FRAMES:
            raise ValueError(
                f"ViT_B_FTCN expects {self._NUM_FRAMES} frames per clip, got {t}"
            )

        # reshape (not view) so non-contiguous inputs are accepted as well.
        images = x.reshape(b * t, 3, h, w)
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.reshape(
            b, t, self._GRID_SIZE, self._GRID_SIZE, c
        )
        # Move channels in front of time: (b, c, t, 14, 14).
        sequence_output = sequence_output.permute(0, 4, 1, 2, 3)

        return self.head(sequence_output)
|
|
if __name__ == '__main__':
    # Smoke test: build the model and push one random batch through it.
    # Use the GPU when one is present instead of failing on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ViT_B_FTCN()
    model = model.to(device)
    dummy_input = torch.randn(4, 8, 3, 224, 224)
    dummy_input = dummy_input.to(device)
    # The output is discarded, so no autograd graph is needed.
    with torch.no_grad():
        model(dummy_input)
|
|
|