| """ |
| TimeSformer-based Video Anomaly Detection Model |
| For deepfake detection via reconstruction |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
class Factorized3DConv(nn.Module):
    """(2+1)D convolution: a spatial (1, k, k) conv followed by a temporal (k, 1, 1) conv.

    Both convolutions are bias-free and each is followed by BatchNorm3d and
    ReLU. Stride ``s`` applies to the two spatial axes only (the temporal conv
    always uses stride 1); padding ``p`` is used spatially by the first conv
    and temporally by the second.
    """

    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        self.spatial = nn.Conv3d(
            in_c, out_c,
            kernel_size=(1, k, k), stride=(1, s, s), padding=(0, p, p),
            bias=False,
        )
        self.temporal = nn.Conv3d(
            out_c, out_c,
            kernel_size=(k, 1, 1), stride=(1, 1, 1), padding=(p, 0, 0),
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(out_c)
        self.bn2 = nn.BatchNorm3d(out_c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.spatial(x)))
        return self.relu(self.bn2(self.temporal(hidden)))
|
|
class OpticalFlowEstimator(nn.Module):
    """Small CNN that predicts a 2-channel flow map from a pair of frames.

    The frames are concatenated along channels, passed through three
    stride-2 conv + ReLU stages (8x spatial down-sampling) and a 1x1 head, and
    the prediction is bilinearly resized back to ``f1``'s spatial size.
    """

    def __init__(self, in_c=3):
        super().__init__()
        self.conv1 = nn.Conv2d(2 * in_c, 64, kernel_size=7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.flow_head = nn.Conv2d(256, 2, kernel_size=1)

    def forward(self, f1, f2):
        """Return a ``(B, 2, H, W)`` flow map for frames ``f1``/``f2`` of shape ``(B, C, H, W)``."""
        h = torch.cat((f1, f2), dim=1)
        for stage in (self.conv1, self.conv2, self.conv3):
            h = F.relu(stage(h))
        coarse = self.flow_head(h)
        return F.interpolate(
            coarse, size=f1.shape[2:], mode='bilinear', align_corners=False
        )
|
|
class PatchEmbedding3D(nn.Module):
    """Cut a video volume into non-overlapping tubelets and project each to ``emb`` dims.

    A Conv3d whose kernel equals its stride, ``(t_sz, p_sz, p_sz)``, performs
    the projection. Tokens are returned in (t, h, w) raster order.

    Note: ``img_sz`` is accepted but not used.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, emb=768, t_sz=2):
        super().__init__()
        tubelet = (t_sz, p_sz, p_sz)
        self.proj = nn.Conv3d(in_c, emb, kernel_size=tubelet, stride=tubelet)

    def forward(self, x):
        """Return ``(tokens, n_tokens)``; ``tokens`` has shape ``(B, T' * H' * W', emb)``."""
        grid = self.proj(x)  # (B, emb, T', H', W')
        n_tokens = grid.shape[2] * grid.shape[3] * grid.shape[4]
        # Flatten (T', H', W') into one axis, then move the embedding last.
        tokens = grid.flatten(2).transpose(1, 2)
        return tokens, n_tokens
|
|
class MultiHeadAttention3D(nn.Module):
    """Multi-head self-attention over a flat ``(B, N, dim)`` token sequence.

    Queries, keys and values come from one fused linear layer; scores are
    scaled by ``(dim // heads) ** -0.5``, soft-maxed, passed through dropout
    ``drop``, applied to the values and re-projected to ``dim``. ``dim`` must
    be divisible by ``heads`` for the per-head reshape to succeed.
    """

    def __init__(self, dim, heads=12, drop=0.):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        batch, n_tok, width = x.shape
        head_dim = width // self.heads
        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q / k / v.
        fused = self.qkv(x).reshape(batch, n_tok, 3, self.heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.drop(torch.softmax(scores, dim=-1))
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, n_tok, width)
        return self.proj(mixed)
|
|
class TransformerBlock3D(nn.Module):
    """Pre-norm transformer block over a ``(B, N, dim)`` token sequence.

    Computes ``x + attn(norm1(x))`` and then adds ``mlp(norm2(.))`` to that
    result. The MLP widens to ``int(dim * mlp_r)`` units with GELU; ``drop``
    is passed to the attention layer and used by both MLP dropout layers.
    """

    def __init__(self, dim, heads, mlp_r=4., drop=0.):
        super().__init__()
        hidden = int(dim * mlp_r)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention3D(dim, heads, drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
|
|
class EnhancedTimeSformer(nn.Module):
    """Transformer over video clips that reconstructs a frame and a flow map.

    Forward pass:

    1. Optical flow is estimated between every pair of consecutive frames and
       lifted to 128 channels by ``flow_enc``.
    2. A two-stage factorized 3D-conv stem down-samples the clip 4x spatially
       (the time axis is kept); the flow features are resized to match and
       added with a fixed weight of 0.1.
    3. The stem output is cut into ``(2, p_sz // 4, p_sz // 4)`` tubelets, a
       class token is prepended, a learned positional embedding is added and
       ``depth`` pre-norm transformer blocks are applied.
    4. A single token is decoded into a coarse ``recon_grid x recon_grid``
       frame and 2-channel flow map, both bilinearly upsampled to the input
       resolution.

    Args:
        img_sz: Nominal frame size; fixes ``recon_grid`` together with ``p_sz``.
        p_sz: Nominal patch size in input pixels.
        in_c: Number of input (and reconstructed) channels.
        n_fr: Accepted for compatibility; not used by this implementation.
        emb: Token embedding width (must be divisible by ``heads``).
        depth: Number of transformer blocks.
        heads: Attention heads per block.
    """

    def __init__(self, img_sz=224, p_sz=16, in_c=3, n_fr=16, emb=768, depth=12, heads=12):
        super().__init__()
        self.img_sz = img_sz
        self.p_sz = p_sz
        # Two stages with spatial stride 2 each -> 4x spatial down-sampling;
        # the temporal convs use stride 1, so T is preserved.
        self.stem = nn.Sequential(
            Factorized3DConv(in_c, 64, 7, 2, 3),
            Factorized3DConv(64, 128, 3, 2, 1)
        )
        self.flow_est = OpticalFlowEstimator(in_c)
        # Lifts the 2-channel flow to the stem's 128 channels so they can be summed.
        self.flow_enc = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 1, 1)
        )
        # Sizes are divided by 4 because patches are cut from the stem output.
        self.patch_emb = PatchEmbedding3D(img_sz // 4, p_sz // 4, 128, emb, 2)
        self.cls_tok = nn.Parameter(torch.zeros(1, 1, emb))
        # Learned table for up to 2048 positions (class token included);
        # longer sequences are linearly resampled in _positional_embedding.
        self.pos_emb = nn.Parameter(torch.zeros(1, 2048, emb))
        self.blocks = nn.ModuleList([
            TransformerBlock3D(emb, heads, 4., 0.1) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb)
        # Side length of the coarse grid the decoders predict before upsampling.
        self.recon_grid = (img_sz // 4) // (p_sz // 4)
        self.dec = nn.ModuleDict({
            'frame': nn.Sequential(
                nn.Linear(emb, emb * 2),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(emb * 2, self.recon_grid ** 2 * in_c)
            ),
            'flow': nn.Sequential(
                nn.Linear(emb, emb),
                nn.GELU(),
                nn.Linear(emb, self.recon_grid ** 2 * 2)
            )
        })
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        nn.init.trunc_normal_(self.cls_tok, std=0.02)

    def _flow_features(self, clip, size, n_frames):
        """Encode consecutive-frame optical flow of ``clip`` ``(B, C, T, H, W)``.

        Returns a ``(B, 128, n_frames, *size)`` tensor: the ``T - 1`` encoded
        flow slices, bilinearly resized to ``size`` and zero-padded at the end
        of the time axis up to ``n_frames``.
        """
        batch = clip.shape[0]
        # Fold every (t, t + 1) frame pair into the batch and run the
        # estimator once. It only contains convs, ReLUs, concatenation and
        # interpolation (no batch-dependent layers), so this equals one call
        # per pair.
        prev = rearrange(clip[:, :, :-1], 'b c t h w -> (b t) c h w')
        nxt = rearrange(clip[:, :, 1:], 'b c t h w -> (b t) c h w')
        feats = self.flow_enc(self.flow_est(prev, nxt))
        feats = F.interpolate(feats, size=size, mode='bilinear', align_corners=False)
        feats = rearrange(feats, '(b t) c h w -> b c t h w', b=batch)
        missing = n_frames - feats.shape[2]
        if missing > 0:
            feats = F.pad(feats, (0, 0, 0, 0, 0, missing))
        return feats

    def _positional_embedding(self, n_tokens):
        """Return a ``(1, n_tokens, emb)`` slice or linear resampling of ``pos_emb``."""
        table = self.pos_emb
        if n_tokens <= table.shape[1]:
            return table[:, :n_tokens]
        # More tokens than learned positions: interpolate along the sequence axis.
        return F.interpolate(
            table.permute(0, 2, 1),
            n_tokens,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _decode(self, tok, channels, out_hw):
        """Decode a ``(B, 1, emb)`` token into frame and flow maps of size ``out_hw``."""
        grid = self.recon_grid
        frame = rearrange(
            self.dec['frame'](tok),
            'b 1 (p1 p2 c) -> b c p1 p2',
            p1=grid,
            p2=grid,
            c=channels
        )
        flow = rearrange(
            self.dec['flow'](tok),
            'b 1 (p1 p2 c) -> b c p1 p2',
            p1=grid,
            p2=grid,
            c=2
        )
        frame = F.interpolate(frame, out_hw, mode='bilinear', align_corners=False)
        flow = F.interpolate(flow, out_hw, mode='bilinear', align_corners=False)
        return frame, flow

    def forward(self, x):
        """Reconstruct a frame and a flow map from a clip.

        Args:
            x: Tensor of shape ``(B, C, T, H, W)`` with ``C == in_c`` and
                ``T >= 2``.

        Returns:
            ``(frame, flow)`` with shapes ``(B, C, H, W)`` and ``(B, 2, H, W)``.

        Raises:
            ValueError: If the clip has fewer than 2 frames (the patch
                embedding uses a temporal kernel of size 2, and flow needs a
                frame pair).
        """
        B, C, T, H, W = x.shape
        if T < 2:
            raise ValueError(
                f"EnhancedTimeSformer needs at least 2 frames (got T={T}): "
                "the patch embedding uses a temporal kernel of size 2"
            )

        feats = self.stem(x)
        feats = feats + 0.1 * self._flow_features(
            x, tuple(feats.shape[3:]), feats.shape[2]
        )

        tokens, n_p = self.patch_emb(feats)
        tokens = torch.cat([self.cls_tok.expand(B, -1, -1), tokens], 1)
        tokens = tokens + self._positional_embedding(tokens.shape[1])
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)

        # NOTE(review): index 0 holds the class token, so ``n_p // 2`` selects
        # patch token ``n_p // 2 - 1``, not patch ``n_p // 2``. Kept unchanged
        # because trained weights depend on it; confirm which token is intended.
        mid = n_p // 2
        return self._decode(tokens[:, mid:mid + 1], C, (H, W))
|
|
def create_model():
    """Return an ``EnhancedTimeSformer`` built with the default configuration.

    224-pixel frames, 16-pixel patches, 3 input channels, ``n_fr=16``,
    768-wide embeddings and 12 transformer blocks with 12 heads each.
    """
    defaults = {
        'img_sz': 224,
        'p_sz': 16,
        'in_c': 3,
        'n_fr': 16,
        'emb': 768,
        'depth': 12,
        'heads': 12,
    }
    return EnhancedTimeSformer(**defaults)
|
|