'''
Exploring Temporal Coherence for More General Video Face Forgery Detection @ ICCV'2021
Copyright (c) Xiamen University and its affiliates.
Modified by Yinglin Zheng from https://github.com/yinglinzheng/FTCN
'''
import random

import torch
from torch import nn

from .time_transformer import TimeTransformer
from .clip import clip

# `my_cfg` is assumed to come from the project's configuration module and to
# expose model.transformer.random_select and model.transformer.k.
from config import config as my_cfg


class RandomPatchPool(nn.Module):
    """Keep a single spatial position per frame, chosen randomly while training."""

    def forward(self, x):
        # x: (batch, channel, time, height, width), e.g. (B, C, 16, 7, 7)
        b, c, t, h, w = x.shape
        x = x.reshape(b, c, t, h * w)
        if self.training and my_cfg.model.transformer.random_select:
            # Rejection-sample a flat index until it lands on a valid
            # (non-border) position; see valid_idx below.
            while True:
                idx = random.randint(0, h * w - 1)
                if valid_idx(idx, h):
                    break
        else:
            # Deterministic choice at eval time: the center position.
            idx = h * w // 2
        x = x[..., idx]
        return x


def valid_idx(idx, h):
    """Treat idx as (row i, col j) on an h x h grid; reject positions in the
    first column, the last column, or the last row."""
    i = idx // h
    j = idx % h
    return not (j == 0 or i == h - 1 or j == h - 1)
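

# Worked example on a 7x7 grid (h == 7): idx 24 maps to (row 3, col 3) and is
# kept; idx 7 maps to (row 1, col 0) and is rejected (first column); idx 48
# maps to (row 6, col 6) and is rejected (last row). Note that the top row
# (i == 0) is not excluded.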


class RandomAvgPool(nn.Module):
    """Average over a random subset of valid (non-border) spatial positions."""

    def forward(self, x):
        # x: (batch, channel, time, height, width), e.g. (B, C, 16, 7, 7)
        b, c, t, h, w = x.shape
        x = x.reshape(b, c, t, h * w)
        candidates = [idx for idx in range(h * w) if valid_idx(idx, h)]
        max_k = len(candidates)
        if self.training and my_cfg.model.transformer.random_select:
            k = my_cfg.model.transformer.k
        else:
            # At eval time, average over all valid positions.
            k = max_k
        candidates = random.sample(candidates, k)
        x = x[..., candidates].mean(-1)
        return x


class TransformerHead(nn.Module):
    def __init__(self, spatial_size=7, time_size=8, in_channels=2048):
        super().__init__()
        # if my_cfg.model.inco.no_time_pool:
        #     time_size = time_size * 2

        # How the 5D feature map is reduced to a token sequence:
        #   "time":       average spatially, one token per frame
        #   "spatial":    average temporally, one token per spatial position
        #   "random":     one random valid spatial position per frame
        #   "random_avg": average of a random subset of valid positions
        #   "all":        every (time, height, width) position is a token
        patch_type = "time"
        if patch_type == "time":
            self.pool = nn.AvgPool3d((1, spatial_size, spatial_size))
            self.num_patches = time_size
        elif patch_type == "spatial":
            self.pool = nn.AvgPool3d((time_size, 1, 1))
            self.num_patches = spatial_size ** 2
        elif patch_type == "random":
            self.pool = RandomPatchPool()
            self.num_patches = time_size
        elif patch_type == "random_avg":
            self.pool = RandomAvgPool()
            self.num_patches = time_size
        elif patch_type == "all":
            self.pool = nn.Identity()
            self.num_patches = time_size * spatial_size * spatial_size
        else:
            raise NotImplementedError(patch_type)

        # dim == -1 means "use the backbone channel width as the token dim".
        self.dim = -1
        if self.dim == -1:
            self.dim = in_channels
        self.in_channels = in_channels
        if self.dim != self.in_channels:
            self.fc = nn.Linear(self.in_channels, self.dim)

        default_params = dict(
            dim=self.dim, depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1,
        )
        self.time_T = TimeTransformer(
            num_patches=self.num_patches, num_classes=1, **default_params
        )

    def forward(self, x):
        # x: (batch, in_channels, time, height, width)
        x = self.pool(x)
        x = x.reshape(-1, self.in_channels, self.num_patches)
        x = x.permute(0, 2, 1)  # -> (batch, num_patches, in_channels)
        if self.dim != self.in_channels:
            x = self.fc(x.reshape(-1, self.in_channels))
            x = x.reshape(-1, self.num_patches, self.dim)
        x = self.time_T(x)
        return x
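

# Shape sketch for the default patch_type == "time" configuration: an input of
# (B, 2048, 8, 7, 7) is average-pooled to (B, 2048, 8, 1, 1) and reshaped to a
# (B, 8, 2048) token sequence; TimeTransformer (num_classes=1) is assumed to
# reduce that to a single logit per clip.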


class ViT_B_FTCN(nn.Module):
    def __init__(self, channel_size=512, class_num=1):
        super().__init__()
        # The bundled CLIP loader provides the ViT-B/16 visual backbone, used
        # here as a per-frame feature extractor (run in float32).
        self.clip_model, preprocess = clip.load('ViT-B-16')
        self.clip_model = self.clip_model.float()
        self.head = TransformerHead(spatial_size=14, time_size=8, in_channels=512)

    def forward(self, x):
        # x: (batch, time, 3, height, width); fold frames into the batch so
        # CLIP encodes each frame independently.
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        # This repo's encode_image returns the patch-token sequence
        # (b * t, 14 * 14, c) rather than a single pooled embedding.
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, 14, 14, c)
        sequence_output = sequence_output.permute(0, 4, 1, 2, 3)  # (b, c, t, 14, 14)
        res = self.head(sequence_output)
        return res


if __name__ == '__main__':
    # Smoke test: a batch of 4 clips, 8 frames each, 224x224 RGB.
    model = ViT_B_FTCN().cuda()
    dummy_input = torch.randn(4, 8, 3, 224, 224).cuda()
    out = model(dummy_input)
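    # Assuming TimeTransformer(num_classes=1) emits one logit per clip, the
    # expected output shape is (4, 1).
    print(out.shape)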