| """ | |
| RhythmFormer:Extracting rPPG Signals Based on Hierarchical Temporal Periodic Transformer | |
| """ | |
| from typing import Optional | |
| import torch | |
| from torch import nn, Tensor, LongTensor | |
| from torch.nn import functional as F | |
| import math | |
| from typing import Tuple, Union | |
| from timm.models.layers import trunc_normal_, DropPath | |
| """ | |
| Adapted from here: https://github.com/rayleizhu/BiFormer | |
| """ | |
| import torch | |
| from torch import Tensor, LongTensor , nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple | |
def _grid2seq(x: Tensor, region_size: Tuple[int, int, int], num_heads: int):
    """
    Args:
        x: BCTHW tensor
        region_size: region size along t/h/w, i.e. (rt, rh, rw)
        num_heads: number of attention heads
    Return:
        out: rearranged x, has a shape of (bs, nhead, nregion, reg_size, head_dim)
        region_t, region_h, region_w: number of regions along t/h/w
    """
    B, C, T, H, W = x.size()
    region_t, region_h, region_w = T // region_size[0], H // region_size[1], W // region_size[2]
    x = x.view(B, num_heads, C // num_heads, region_t, region_size[0], region_h, region_size[1], region_w, region_size[2])
    x = torch.einsum('bmdtohpwq->bmthwopqd', x).flatten(2, 4).flatten(-4, -2)  # (bs, nhead, nregion, reg_size, head_dim)
    return x, region_t, region_h, region_w
def _seq2grid(x: Tensor, region_t: int, region_h: int, region_w: int, region_size: Tuple[int, int, int]):
    """
    Args:
        x: (bs, nhead, nregion, reg_size, head_dim) tensor, i.e. the output format of _grid2seq
    Return:
        x: (bs, C, T, H, W)
    """
    bs, nhead, nregion, reg_size, head_dim = x.size()
    x = x.view(bs, nhead, region_t, region_h, region_w, region_size[0], region_size[1], region_size[2], head_dim)
    x = torch.einsum('bmthwopqd->bmdtohpwq', x).reshape(bs, nhead * head_dim,
                                                        region_t * region_size[0], region_h * region_size[1], region_w * region_size[2])
    return x
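
# A minimal sanity-check sketch (not part of the original model): it uses a toy
# (1, 8, 4, 8, 8) input and verifies that _seq2grid exactly inverts _grid2seq,
# since both are pure rearrangements.
def _demo_grid_seq_roundtrip():
    x = torch.randn(1, 8, 4, 8, 8)  # (B, C, T, H, W)
    seq, rt, rh, rw = _grid2seq(x, region_size=(2, 4, 4), num_heads=2)
    # 2*2*2 = 8 regions of 2*4*4 = 32 tokens each, head_dim = 8 // 2 = 4
    assert seq.shape == (1, 2, 8, 32, 4)
    y = _seq2grid(seq, rt, rh, rw, region_size=(2, 4, 4))
    assert torch.equal(x, y)  # lossless round trip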
def video_regional_routing_attention_torch(
        query: Tensor, key: Tensor, value: Tensor, scale: float,
        region_graph: LongTensor, region_size: Tuple[int, int, int],
        kv_region_size: Optional[Tuple[int, int, int]] = None,
        auto_pad=False) -> Tensor:
    """
    Args:
        query, key, value: (B, C, T, H, W) tensor
        scale: the scale/temperature for dot product attention
        region_graph: (B, nhead, t_q*h_q*w_q, topk) tensor, topk <= t_k*h_k*w_k
        region_size: region/window size for queries, (rt, rh, rw)
        kv_region_size: optional, if None, kv_region_size = region_size
    Return:
        output: (B, C, T, H, W) tensor
        attn: (bs, nhead, q_nregion, reg_size, topk*kv_region_tokens) attention matrix
    """
    kv_region_size = kv_region_size or region_size
    bs, nhead, q_nregion, topk = region_graph.size()

    # NOTE: auto_pad is not implemented for the video (3D) case; the 2D padding
    # logic below is kept, commented out, from the BiFormer source for reference.
    # q_pad_b, q_pad_r, kv_pad_b, kv_pad_r = 0, 0, 0, 0
    # if auto_pad:
    #     _, _, Hq, Wq = query.size()
    #     q_pad_b = (region_size[0] - Hq % region_size[0]) % region_size[0]
    #     q_pad_r = (region_size[1] - Wq % region_size[1]) % region_size[1]
    #     if (q_pad_b > 0 or q_pad_r > 0):
    #         query = F.pad(query, (0, q_pad_r, 0, q_pad_b))  # zero padding
    #     _, _, Hk, Wk = key.size()
    #     kv_pad_b = (kv_region_size[0] - Hk % kv_region_size[0]) % kv_region_size[0]
    #     kv_pad_r = (kv_region_size[1] - Wk % kv_region_size[1]) % kv_region_size[1]
    #     if (kv_pad_r > 0 or kv_pad_b > 0):
    #         key = F.pad(key, (0, kv_pad_r, 0, kv_pad_b))  # zero padding
    #         value = F.pad(value, (0, kv_pad_r, 0, kv_pad_b))  # zero padding

    # to sequence format, i.e. (bs, nhead, nregion, reg_size, head_dim)
    query, q_region_t, q_region_h, q_region_w = _grid2seq(query, region_size=region_size, num_heads=nhead)
    key, _, _, _ = _grid2seq(key, region_size=kv_region_size, num_heads=nhead)
    value, _, _, _ = _grid2seq(value, region_size=kv_region_size, num_heads=nhead)

    # gather keys and values; torch.gather does not support broadcasting, hence
    # the manual expand. kv_region_tokens is the token count per kv region (an
    # int), distinct from the kv_region_size tuple argument above.
    bs, nhead, kv_nregion, kv_region_tokens, head_dim = key.size()
    broadcasted_region_graph = region_graph.view(bs, nhead, q_nregion, topk, 1, 1).\
        expand(-1, -1, -1, -1, kv_region_tokens, head_dim)
    key_g = torch.gather(key.view(bs, nhead, 1, kv_nregion, kv_region_tokens, head_dim).\
        expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
        index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_tokens, head_dim)
    value_g = torch.gather(value.view(bs, nhead, 1, kv_nregion, kv_region_tokens, head_dim).\
        expand(-1, -1, query.size(2), -1, -1, -1), dim=3,
        index=broadcasted_region_graph)  # (bs, nhead, q_nregion, topk, kv_region_tokens, head_dim)

    # token-to-token attention
    # (bs, nhead, q_nregion, reg_size, head_dim) @ (bs, nhead, q_nregion, head_dim, topk*kv_region_tokens)
    # -> (bs, nhead, q_nregion, reg_size, topk*kv_region_tokens)
    attn = (query * scale) @ key_g.flatten(-3, -2).transpose(-1, -2)
    attn = torch.softmax(attn, dim=-1)
    # (bs, nhead, q_nregion, reg_size, topk*kv_region_tokens) @ (bs, nhead, q_nregion, topk*kv_region_tokens, head_dim)
    # -> (bs, nhead, q_nregion, reg_size, head_dim)
    output = attn @ value_g.flatten(-3, -2)

    # back to BCTHW format
    output = _seq2grid(output, region_t=q_region_t, region_h=q_region_h, region_w=q_region_w, region_size=region_size)

    # remove paddings if needed (2D reference code, see the note above)
    # if auto_pad and (q_pad_b > 0 or q_pad_r > 0):
    #     output = output[:, :, :Hq, :Wq]
    return output, attn
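
# A hedged usage sketch (not in the original repo): toy shapes only, with a random
# region_graph standing in for the routing indices that video_BRA derives from
# region-pooled query/key similarity.
def _demo_video_regional_routing_attention():
    B, C, T, H, W, nhead, topk = 1, 8, 4, 8, 8, 2, 2
    q, k, v = (torch.randn(B, C, T, H, W) for _ in range(3))
    region_size = (2, 4, 4)  # 2*2*2 = 8 regions of 32 tokens each
    region_graph = torch.randint(0, 8, (B, nhead, 8, topk))
    out, attn = video_regional_routing_attention_torch(
        query=q, key=k, value=v, scale=(C // nhead) ** -0.5,
        region_graph=region_graph, region_size=region_size)
    assert out.shape == (B, C, T, H, W)
    assert attn.shape == (B, nhead, 8, 32, topk * 32)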
class CDC_T(nn.Module):
    """
    Central difference temporal convolution.
    The CDC_T module is from: https://github.com/ZitongYu/PhysFormer (model/transformer_layer.py)
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.6):
        super(CDC_T, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):
        out_normal = self.conv(x)
        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal
        else:
            # the central difference term only applies when the temporal kernel size > 1
            if self.conv.weight.shape[2] > 1:
                kernel_diff = self.conv.weight[:, :, 0, :, :].sum(2).sum(2) + \
                              self.conv.weight[:, :, 2, :, :].sum(2).sum(2)
                kernel_diff = kernel_diff[:, :, None, None, None]
                out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
                                    padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
                return out_normal - self.theta * out_diff
            else:
                return out_normal
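
# A minimal check sketch (not part of the original code): with theta = 0 the
# central difference term vanishes and CDC_T reduces to its plain Conv3d.
def _demo_cdc_t():
    x = torch.randn(1, 4, 8, 16, 16)
    cdc = CDC_T(4, 4, theta=0.0)
    assert torch.allclose(cdc(x), cdc.conv(x))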
class video_BRA(nn.Module):
    """Video Bi-level Routing Attention: region-level routing followed by token-level attention."""
    def __init__(self, dim, num_heads=8, t_patch=8, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False, attn_backend='torch'):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
        self.head_dim = self.dim // self.num_heads
        self.scale = qk_scale or self.dim ** -0.5
        self.topk = topk
        self.t_patch = t_patch  # number of frames per temporal patch

        ################ side_dwconv (i.e. LCE in Shunted Transformer) ################
        self.lepe = nn.Conv3d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)
        ################################################################################

        self.qkv_linear = nn.Conv3d(self.dim, 3*self.dim, kernel_size=1)
        self.output_linear = nn.Conv3d(self.dim, self.dim, kernel_size=1)
        self.proj_q = nn.Sequential(
            CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2),
            nn.BatchNorm3d(dim),
        )
        self.proj_k = nn.Sequential(
            CDC_T(dim, dim, 3, stride=1, padding=1, groups=1, bias=False, theta=0.2),
            nn.BatchNorm3d(dim),
        )
        self.proj_v = nn.Sequential(
            nn.Conv3d(dim, dim, 1, stride=1, padding=0, groups=1, bias=False),
        )
        if attn_backend == 'torch':
            self.attn_fn = video_regional_routing_attention_torch
        else:
            raise ValueError('CUDA implementation is not available yet. Please stay tuned.')

    def forward(self, x: Tensor):
        N, C, T, H, W = x.size()
        t_region = max(4 // self.t_patch, 1)
        region_size = (t_region, H//4, W//4)

        # STEP 1: linear projections
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)

        # STEP 2: region-to-region routing - rank regions by pooled q/k similarity
        q_r = F.avg_pool3d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
        k_r = F.avg_pool3d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)  # ncthw
        q_r: Tensor = q_r.permute(0, 2, 3, 4, 1).flatten(1, 3)  # n(thw)c
        k_r: Tensor = k_r.flatten(2, 4)  # nc(thw)
        a_r = q_r @ k_r  # n(thw)(thw)
        _, idx_r = torch.topk(a_r, k=self.topk, dim=-1)  # n(thw)k
        idx_r: LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)

        # STEP 3: token-to-token attention, refined within the routed regions
        output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
                                        region_graph=idx_r, region_size=region_size)
        output = output + self.lepe(v)  # ncthw
        output = self.output_linear(output)  # ncthw
        return output
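
# A hedged shape-check sketch (not part of the original model): t_patch=2 yields a
# (2, H//4, W//4) attention region, so T must be even and H, W divisible by 4 here.
def _demo_video_bra():
    m = video_BRA(dim=64, num_heads=4, t_patch=2, topk=4).eval()
    x = torch.randn(1, 64, 8, 16, 16)  # (N, C, T, H, W)
    with torch.no_grad():
        y = m(x)
    assert y.shape == x.shape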
class video_BiFormerBlock(nn.Module):
    def __init__(self, dim, drop_path=0., num_heads=4, t_patch=1, qk_scale=None, topk=4, mlp_ratio=2, side_dwconv=5):
        super().__init__()
        self.t_patch = t_patch
        self.norm1 = nn.BatchNorm3d(dim)
        self.attn = video_BRA(dim=dim, num_heads=num_heads, t_patch=t_patch, qk_scale=qk_scale, topk=topk, side_dwconv=side_dwconv)
        self.norm2 = nn.BatchNorm3d(dim)
        self.mlp = nn.Sequential(nn.Conv3d(dim, int(mlp_ratio*dim), kernel_size=1),
                                 nn.BatchNorm3d(int(mlp_ratio*dim)),
                                 nn.GELU(),
                                 nn.Conv3d(int(mlp_ratio*dim), int(mlp_ratio*dim), 3, stride=1, padding=1),
                                 nn.BatchNorm3d(int(mlp_ratio*dim)),
                                 nn.GELU(),
                                 nn.Conv3d(int(mlp_ratio*dim), dim, kernel_size=1),
                                 nn.BatchNorm3d(dim),
                                 )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # pre-norm residual attention followed by a residual convolutional MLP
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
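
# A hedged usage sketch (not in the original repo): one block preserves the
# (N, C, T, H, W) shape.
def _demo_video_biformer_block():
    blk = video_BiFormerBlock(dim=64, num_heads=4, t_patch=2, topk=4).eval()
    x = torch.randn(1, 64, 8, 16, 16)
    with torch.no_grad():
        assert blk(x).shape == x.shape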
class Fusion_Stem(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5):
        super(Fusion_Stem, self).__init__()
        self.stem11 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
                                    nn.BatchNorm2d(64),
                                    nn.ReLU(inplace=True),
                                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                    )
        self.stem12 = nn.Sequential(nn.Conv2d(12, 64, kernel_size=5, stride=2, padding=2),
                                    nn.BatchNorm2d(64),
                                    nn.ReLU(inplace=True),
                                    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                                    )
        self.stem21 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.stem22 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        """Definition of Fusion_Stem.
        Args:
            x [N,D,C,H,W]
        Returns:
            fusion_x [N*D,64,H/4,W/4]
        """
        N, D, C, H, W = x.shape
        # edge-replicated temporally shifted copies of the clip:
        # x1/x2 lag the input by two/one frames, x4/x5 lead it by one/two frames
        x1 = torch.cat([x[:, :1, :, :, :], x[:, :1, :, :, :], x[:, :D-2, :, :, :]], 1)
        x2 = torch.cat([x[:, :1, :, :, :], x[:, :D-1, :, :, :]], 1)
        x3 = x
        x4 = torch.cat([x[:, 1:, :, :, :], x[:, D-1:, :, :, :]], 1)
        x5 = torch.cat([x[:, 2:, :, :, :], x[:, D-1:, :, :, :], x[:, D-1:, :, :, :]], 1)
        # stack the four consecutive frame differences along channels (4 x 3 = 12)
        x_diff = self.stem12(torch.cat([x2-x1, x3-x2, x4-x3, x5-x4], 2).view(N * D, 12, H, W))
        x3 = x3.contiguous().view(N * D, C, H, W)
        x = self.stem11(x3)
        # fusion layer 1: blend raw-frame and frame-difference features
        x_path1 = self.alpha * x + self.beta * x_diff
        x_path1 = self.stem21(x_path1)
        # fusion layer 2
        x_path2 = self.stem22(x_diff)
        x = self.alpha * x_path1 + self.beta * x_path2
        return x
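
# A hedged usage sketch (not part of the original code): an 8-frame 64x64 toy clip;
# the stem folds time into the batch and downsamples space by 4x.
def _demo_fusion_stem():
    stem = Fusion_Stem().eval()
    x = torch.randn(2, 8, 3, 64, 64)  # (N, D, C, H, W)
    with torch.no_grad():
        y = stem(x)
    assert y.shape == (2 * 8, 64, 16, 16)  # (N*D, 64, H/4, W/4)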
class TPT_Block(nn.Module):
    def __init__(self, dim, depth, num_heads, t_patch, topk,
                 mlp_ratio=4., drop_path=0., side_dwconv=5):
        super().__init__()
        self.dim = dim
        self.depth = depth

        ############ downsample layers & upsample layers #####################
        self.downsample_layers = nn.ModuleList()
        self.upsample_layers = nn.ModuleList()
        self.layer_n = int(math.log2(t_patch))
        for i in range(self.layer_n):
            downsample_layer = nn.Sequential(
                nn.BatchNorm3d(dim),
                nn.Conv3d(dim, dim, kernel_size=(2, 1, 1), stride=(2, 1, 1)),
            )
            self.downsample_layers.append(downsample_layer)
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=(2, 1, 1)),
                nn.Conv3d(dim, dim, [3, 1, 1], stride=1, padding=(1, 0, 0)),
                nn.BatchNorm3d(dim),
                nn.ELU(),
            )
            self.upsample_layers.append(upsample_layer)
        ######################################################################

        self.blocks = nn.ModuleList([
            video_BiFormerBlock(
                dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                num_heads=num_heads,
                t_patch=t_patch,
                topk=topk,
                mlp_ratio=mlp_ratio,
                side_dwconv=side_dwconv,
            )
            for i in range(depth)
        ])

    def forward(self, x: torch.Tensor):
        """Definition of TPT_Block.
        Args:
            x [N,C,D,H,W]
        Returns:
            x [N,C,D,H,W]
        """
        for i in range(self.layer_n):
            x = self.downsample_layers[i](x)
        for blk in self.blocks:
            x = blk(x)
        for i in range(self.layer_n):
            x = self.upsample_layers[i](x)
        return x
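
# A hedged usage sketch (not in the original repo): with t_patch=2 the block halves
# the temporal axis once, attends, then upsamples back, so shapes are preserved.
def _demo_tpt_block():
    blk = TPT_Block(dim=64, depth=2, num_heads=4, t_patch=2, topk=4).eval()
    x = torch.randn(1, 64, 8, 16, 16)  # (N, C, D, H, W)
    with torch.no_grad():
        assert blk(x).shape == x.shape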
class RhythmFormer(nn.Module):
    def __init__(
        self,
        name: Optional[str] = None,
        pretrained: bool = False,
        dim: int = 64, frame: int = 160,
        image_size: Optional[Tuple[int, int, int]] = (160, 128, 128),
        in_chans=64, head_dim=16,
        stage_n=3,
        embed_dim=[64, 64, 64], mlp_ratios=[1.5, 1.5, 1.5],
        depth=[2, 2, 2],
        t_patchs: Union[int, Tuple[int, ...]] = (2, 4, 8),
        topks: Union[int, Tuple[int, ...]] = (40, 40, 40),
        side_dwconv: int = 3,
        drop_path_rate=0.,
        use_checkpoint_stages=[],
    ):
        super().__init__()
        self.image_size = image_size
        self.frame = frame
        self.dim = dim
        self.stage_n = stage_n

        self.Fusion_Stem = Fusion_Stem()
        self.patch_embedding = nn.Conv3d(in_chans, embed_dim[0], kernel_size=(1, 4, 4), stride=(1, 4, 4))
        self.ConvBlockLast = nn.Conv1d(embed_dim[-1], 1, kernel_size=1, stride=1, padding=0)

        ##########################################################################
        self.stages = nn.ModuleList()
        nheads = [d // head_dim for d in embed_dim]
        # stochastic depth: one rate per block, increasing linearly across stages
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]
        for i in range(stage_n):
            stage = TPT_Block(dim=embed_dim[i],
                              depth=depth[i],
                              num_heads=nheads[i],
                              mlp_ratio=mlp_ratios[i],
                              drop_path=dp_rates[sum(depth[:i]):sum(depth[:i+1])],
                              t_patch=t_patchs[i], topk=topks[i], side_dwconv=side_dwconv,
                              )
            self.stages.append(stage)
        ##########################################################################

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        N, D, C, H, W = x.shape
        x = self.Fusion_Stem(x)  # [N*D, 64, H/4, W/4]
        x = x.view(N, D, 64, H//4, W//4).permute(0, 2, 1, 3, 4)
        x = self.patch_embedding(x)  # [N, 64, D, 8, 8]
        for i in range(self.stage_n):
            x = self.stages[i](x)  # [N, 64, D, 8, 8]
        features_last = torch.mean(x, 3)  # [N, 64, D, 8]
        features_last = torch.mean(features_last, 3)  # [N, 64, D]
        rPPG = self.ConvBlockLast(features_last)  # [N, 1, D]
        rPPG = rPPG.squeeze(1)
        return rPPG
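
# A hedged end-to-end smoke test (not part of the original file), assuming the
# default configuration: a 160-frame 128x128 clip in, a 160-sample rPPG signal out.
if __name__ == "__main__":
    model = RhythmFormer().eval()
    clip = torch.randn(1, 160, 3, 128, 128)  # (N, D, C, H, W)
    with torch.no_grad():
        rppg = model(clip)
    print(rppg.shape)  # expected: torch.Size([1, 160])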