import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from timm.models.layers import trunc_normal_, DropPath
from mamba_ssm import Mamba
class ChannelAttention3D(nn.Module):
    """Channel attention over 3D feature maps: average- and max-pooled descriptors
    share a bottleneck MLP, and their sum gates the input channels."""

    def __init__(self, in_channels, reduction):
        super(ChannelAttention3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.fc = nn.Sequential(
            nn.Conv3d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv3d(in_channels // reduction, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = self.sigmoid(avg_out + max_out)
        return x * attention
class LateralConnection(nn.Module):
    """Fuses the fast stream into the slow stream: a temporally strided conv matches
    the fast path to the slow path's temporal length and channel width, then the two
    are added element-wise."""

    def __init__(self, fast_channels=32, slow_channels=64):
        super(LateralConnection, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(fast_channels, slow_channels, [3, 1, 1], stride=[2, 1, 1], padding=[1, 0, 0]),
            nn.BatchNorm3d(slow_channels),
            nn.ReLU(),
        )

    def forward(self, slow_path, fast_path):
        fast_path = self.conv(fast_path)
        return fast_path + slow_path
class CDC_T(nn.Module):
    """Temporal center-difference 3D convolution: the vanilla Conv3d output minus
    theta times a difference term built from the first and last temporal slices of
    the kernel. With theta = 0 it reduces to a plain Conv3d."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.2):
        super(CDC_T, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):
        out_normal = self.conv(x)
        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal

        # The center-difference term only applies when the temporal kernel size is > 1.
        if self.conv.weight.shape[2] > 1:
            # Collapse the first and last temporal kernel slices (spatially summed)
            # into a 1x1x1 kernel, and apply it as a pointwise 3D convolution.
            kernel_diff = (self.conv.weight[:, :, 0, :, :].sum(2).sum(2)
                           + self.conv.weight[:, :, 2, :, :].sum(2).sum(2))
            kernel_diff = kernel_diff[:, :, None, None, None]
            out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
                                padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
            return out_normal - self.theta * out_diff
        return out_normal
class MambaLayer(nn.Module):
    """Runs a bidirectional Mamba block over a 3D feature map: the (B, C, T, H, W)
    map is flattened into a token sequence, normalized, passed through Mamba with a
    residual connection, and reshaped back."""

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, channel_token=False):
        super(MambaLayer, self).__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        drop_path = 0
        self.mamba = Mamba(
            d_model=dim,      # model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,    # local convolution width
            expand=expand,    # block expansion factor
            bimamba=True,     # bidirectional scan (needs a mamba_ssm build that supports it)
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward_patch_token(self, x):
        B, d_model = x.shape[:2]
        assert d_model == self.dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2)  # (B, n_tokens, d_model)
        x_norm = self.norm1(x_flat)
        x_mamba = self.mamba(x_norm)
        x_out = self.norm2(x_flat + self.drop_path(x_mamba))
        out = x_out.transpose(-1, -2).reshape(B, d_model, *img_dims)
        return out

    def forward(self, x):
        # The selective scan runs in fp32 here, so cast half-precision inputs up.
        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
            x = x.type(torch.float32)
        return self.forward_patch_token(x)
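# Illustrative shape check for MambaLayer (a sketch, assuming the bidirectional
# mamba_ssm build imported above is installed and a CUDA device is available):
#
#     layer = MambaLayer(dim=64).cuda()
#     feats = torch.randn(2, 64, 32, 16, 16, device="cuda")  # (B, C, T, H, W)
#     out = layer(feats)                                      # same shape as the input
#
# Internally the map is flattened to (B, T*H*W, C) tokens, run through Mamba with a
# residual connection, and reshaped back to (B, C, T, H, W).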
def conv_block(in_channels, out_channels, kernel_size, stride, padding, bn=True, activation='relu'):
    layers = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)]
    if bn:
        layers.append(nn.BatchNorm3d(out_channels))
    if activation == 'relu':
        layers.append(nn.ReLU(inplace=True))
    elif activation == 'elu':
        layers.append(nn.ELU(inplace=True))
    return nn.Sequential(*layers)
class PhysMamba(nn.Module):
    def __init__(self, theta=0.5, drop_rate1=0.25, drop_rate2=0.5, frames=128):
        super(PhysMamba, self).__init__()
        self.ConvBlock1 = conv_block(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2])
        self.ConvBlock2 = conv_block(16, 32, [3, 3, 3], stride=1, padding=1)
        self.ConvBlock3 = conv_block(32, 64, [3, 3, 3], stride=1, padding=1)
        self.ConvBlock4 = conv_block(64, 64, [4, 1, 1], stride=[4, 1, 1], padding=0)
        self.ConvBlock5 = conv_block(64, 32, [2, 1, 1], stride=[2, 1, 1], padding=0)
        self.ConvBlock6 = conv_block(32, 32, [3, 1, 1], stride=1, padding=[1, 0, 0], activation='elu')
        # Temporal Difference Mamba blocks
        # Slow stream
        self.Block1 = self._build_block(64, theta)
        self.Block2 = self._build_block(64, theta)
        self.Block3 = self._build_block(64, theta)
        # Fast stream
        self.Block4 = self._build_block(32, theta)
        self.Block5 = self._build_block(32, theta)
        self.Block6 = self._build_block(32, theta)
        # Upsampling
        self.upsample1 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(64, 64, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )
        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=(2, 1, 1)),
            nn.Conv3d(96, 48, [3, 1, 1], stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(48),
            nn.ELU(),
        )
        self.ConvBlockLast = nn.Conv3d(48, 1, [1, 1, 1], stride=1, padding=0)
        self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
        self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2)
        self.fuse_1 = LateralConnection(fast_channels=32, slow_channels=64)
        self.fuse_2 = LateralConnection(fast_channels=32, slow_channels=64)
        self.drop_1 = nn.Dropout(drop_rate1)
        self.drop_2 = nn.Dropout(drop_rate1)
        self.drop_3 = nn.Dropout(drop_rate2)
        self.drop_4 = nn.Dropout(drop_rate2)
        self.drop_5 = nn.Dropout(drop_rate2)
        self.drop_6 = nn.Dropout(drop_rate2)
        self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))

    def _build_block(self, channels, theta):
        return nn.Sequential(
            CDC_T(channels, channels, theta=theta),
            nn.BatchNorm3d(channels),
            nn.ReLU(),
            MambaLayer(dim=channels),
            ChannelAttention3D(in_channels=channels, reduction=2),
        )
    def forward(self, x):
        batch, channel, length, width, height = x.shape
        x = self.ConvBlock1(x)
        x = self.MaxpoolSpa(x)
        x = self.ConvBlock2(x)
        x = self.ConvBlock3(x)
        x = self.MaxpoolSpa(x)
        # Split into the two streams: slow (temporal stride 4) and fast (temporal stride 2)
        s_x = self.ConvBlock4(x)  # slow stream
        f_x = self.ConvBlock5(x)  # fast stream
        # First set of blocks and fusion
        s_x1 = self.Block1(s_x)
        s_x1 = self.MaxpoolSpa(s_x1)
        s_x1 = self.drop_1(s_x1)
        f_x1 = self.Block4(f_x)
        f_x1 = self.MaxpoolSpa(f_x1)
        f_x1 = self.drop_2(f_x1)
        s_x1 = self.fuse_1(s_x1, f_x1)  # lateral connection
        # Second set of blocks and fusion
        s_x2 = self.Block2(s_x1)
        s_x2 = self.MaxpoolSpa(s_x2)
        s_x2 = self.drop_3(s_x2)
        f_x2 = self.Block5(f_x1)
        f_x2 = self.MaxpoolSpa(f_x2)
        f_x2 = self.drop_4(f_x2)
        s_x2 = self.fuse_2(s_x2, f_x2)  # lateral connection
        # Third set of blocks and upsampling
        s_x3 = self.Block3(s_x2)
        s_x3 = self.upsample1(s_x3)
        s_x3 = self.drop_5(s_x3)
        f_x3 = self.Block6(f_x2)
        f_x3 = self.ConvBlock6(f_x3)
        f_x3 = self.drop_6(f_x3)
        # Final fusion and upsampling back to the input temporal length
        x_fusion = torch.cat((f_x3, s_x3), dim=1)
        x_final = self.upsample2(x_fusion)
        x_final = self.poolspa(x_final)
        x_final = self.ConvBlockLast(x_final)
        rPPG = x_final.view(-1, length)
        return rPPG
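# Minimal usage sketch: a forward pass on random data, assuming 128-frame clips at
# 128x128 resolution, a CUDA device, and a mamba_ssm build with bimamba support.
# Adjust `frames` to match the clip length fed to the model.
if __name__ == "__main__":
    model = PhysMamba(theta=0.5, frames=128).cuda()
    clips = torch.randn(2, 3, 128, 128, 128, device="cuda")  # (B, C, T, H, W)
    rppg = model(clips)                                       # (B, T) rPPG waveform
    print(rppg.shape)                                         # torch.Size([2, 128])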