import math
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from mamba_ssm import Mamba
from torch.nn import functional as F
class ChannelAttention3D(nn.Module):
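    """Channel attention for 3D feature maps.

    Global average- and max-pooled descriptors are passed through a shared
    two-layer 1x1x1 convolutional bottleneck, summed, and squashed with a
    sigmoid to produce per-channel weights that rescale the input.
    """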
def __init__(self, in_channels, reduction):
super(ChannelAttention3D, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.max_pool = nn.AdaptiveMaxPool3d(1)
self.fc = nn.Sequential(
nn.Conv3d(in_channels, in_channels // reduction, 1, bias=False),
nn.ReLU(),
nn.Conv3d(in_channels // reduction, in_channels, 1, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
attention = self.sigmoid(out)
return x*attention
class LateralConnection(nn.Module):
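    """Fuses the fast (temporally fine) stream into the slow stream.

    A temporally strided convolution halves the fast stream's frame count and
    projects it to the slow stream's channel width before element-wise addition.
    """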
def __init__(self, fast_channels=32, slow_channels=64):
super(LateralConnection, self).__init__()
self.conv = nn.Sequential(
nn.Conv3d(fast_channels, slow_channels, [3, 1, 1], stride=[2, 1, 1], padding=[1,0,0]),
            nn.BatchNorm3d(slow_channels),
nn.ReLU(),
)
def forward(self, slow_path, fast_path):
fast_path = self.conv(fast_path)
return fast_path + slow_path
class CDC_T(nn.Module):
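    """Temporal center-difference convolution (CDC-T).

    Computes a vanilla 3D convolution and, when theta is non-zero and the
    temporal kernel size is greater than 1, subtracts theta times a difference
    term built from the outer temporal slices of the kernel, emphasising
    frame-to-frame changes.
    """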
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=0.2):
super(CDC_T, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def forward(self, x):
out_normal = self.conv(x)
if math.fabs(self.theta - 0.0) < 1e-8:
return out_normal
else:
            # The central difference term only applies when the temporal kernel
            # size is greater than 1 (the indexing below assumes a size of 3).
            if self.conv.weight.shape[2] > 1:
                kernel_diff = (self.conv.weight[:, :, 0, :, :].sum(dim=(2, 3))
                               + self.conv.weight[:, :, 2, :, :].sum(dim=(2, 3)))
                kernel_diff = kernel_diff[:, :, None, None, None]
out_diff = F.conv3d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride,
padding=0, dilation=self.conv.dilation, groups=self.conv.groups)
return out_normal - self.theta * out_diff
else:
return out_normal
class MambaLayer(nn.Module):
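    """Pre-norm Mamba block over flattened spatio-temporal tokens.

    The (C, T, H, W) volume is flattened to a token sequence of length T*H*W,
    normalised, passed through a bidirectional Mamba SSM, added back
    residually, and normalised again. Note that `bimamba=True` is not accepted
    by the stock `mamba_ssm` package; it assumes the bidirectional-Mamba
    variant shipped with this repository.
    """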
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, channel_token=False):
super(MambaLayer, self).__init__()
self.dim = dim
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
drop_path = 0
self.mamba = Mamba(
d_model=dim, # Model dimension d_model
d_state=d_state, # SSM state expansion factor
d_conv=d_conv, # Local convolution width
expand=expand, # Block expansion factor
bimamba=True,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward_patch_token(self, x):
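        # Flatten (C, T, H, W) into a token sequence of length T*H*W, run the
        # Mamba block with a residual connection, then restore the input shape.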
        B, d_model = x.shape[:2]
        assert d_model == self.dim
        img_dims = x.shape[2:]
        n_tokens = img_dims.numel()
        x_flat = x.reshape(B, d_model, n_tokens).transpose(-1, -2)
x_norm = self.norm1(x_flat)
x_mamba = self.mamba(x_norm)
x_out = self.norm2(x_flat + self.drop_path(x_mamba))
out = x_out.transpose(-1, -2).reshape(B, d_model, *img_dims)
return out
def forward(self, x):
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.type(torch.float32)
out = self.forward_patch_token(x)
return out
def conv_block(in_channels, out_channels, kernel_size, stride, padding, bn=True, activation='relu'):
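    """Conv3d followed by optional BatchNorm3d and a ReLU or ELU activation, as an nn.Sequential."""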
layers = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)]
if bn:
layers.append(nn.BatchNorm3d(out_channels))
if activation == 'relu':
layers.append(nn.ReLU(inplace=True))
elif activation == 'elu':
layers.append(nn.ELU(inplace=True))
return nn.Sequential(*layers)
class PhysMamba(nn.Module):
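    """Two-stream (slow/fast) temporal-difference Mamba network for rPPG.

    Expects a facial video clip of shape (B, 3, T, H, W) with T matching
    `frames` and returns a (B, T) rPPG waveform. The slow stream runs at T/4
    and the fast stream at T/2; lateral connections fuse fast into slow, and
    the two streams are concatenated and upsampled back to full temporal
    resolution.
    """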
def __init__(self, theta=0.5, drop_rate1=0.25, drop_rate2=0.5, frames=128):
super(PhysMamba, self).__init__()
self.ConvBlock1 = conv_block(3, 16, [1, 5, 5], stride=1, padding=[0, 2, 2])
self.ConvBlock2 = conv_block(16, 32, [3, 3, 3], stride=1, padding=1)
self.ConvBlock3 = conv_block(32, 64, [3, 3, 3], stride=1, padding=1)
self.ConvBlock4 = conv_block(64, 64, [4, 1, 1], stride=[4, 1, 1], padding=0)
self.ConvBlock5 = conv_block(64, 32, [2, 1, 1], stride=[2, 1, 1], padding=0)
self.ConvBlock6 = conv_block(32, 32, [3, 1, 1], stride=1, padding=[1, 0, 0], activation='elu')
# Temporal Difference Mamba Blocks
# Slow Stream
self.Block1 = self._build_block(64, theta)
self.Block2 = self._build_block(64, theta)
self.Block3 = self._build_block(64, theta)
# Fast Stream
self.Block4 = self._build_block(32, theta)
self.Block5 = self._build_block(32, theta)
self.Block6 = self._build_block(32, theta)
# Upsampling
self.upsample1 = nn.Sequential(
nn.Upsample(scale_factor=(2,1,1)),
nn.Conv3d(64, 64, [3, 1, 1], stride=1, padding=(1,0,0)),
nn.BatchNorm3d(64),
nn.ELU(),
)
self.upsample2 = nn.Sequential(
nn.Upsample(scale_factor=(2,1,1)),
nn.Conv3d(96, 48, [3, 1, 1], stride=1, padding=(1,0,0)),
nn.BatchNorm3d(48),
nn.ELU(),
)
self.ConvBlockLast = nn.Conv3d(48, 1, [1, 1, 1], stride=1, padding=0)
self.MaxpoolSpa = nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2))
self.MaxpoolSpaTem = nn.MaxPool3d((2, 2, 2), stride=2)
self.fuse_1 = LateralConnection(fast_channels=32, slow_channels=64)
self.fuse_2 = LateralConnection(fast_channels=32, slow_channels=64)
self.drop_1 = nn.Dropout(drop_rate1)
self.drop_2 = nn.Dropout(drop_rate1)
self.drop_3 = nn.Dropout(drop_rate2)
self.drop_4 = nn.Dropout(drop_rate2)
self.drop_5 = nn.Dropout(drop_rate2)
self.drop_6 = nn.Dropout(drop_rate2)
self.poolspa = nn.AdaptiveAvgPool3d((frames, 1, 1))
def _build_block(self, channels, theta):
return nn.Sequential(
CDC_T(channels, channels, theta=theta),
nn.BatchNorm3d(channels),
nn.ReLU(),
MambaLayer(dim=channels),
ChannelAttention3D(in_channels=channels, reduction=2),
)
def forward(self, x):
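        # x: (B, 3, T, H, W) video clip -> rPPG waveform of shape (B, T).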
        length = x.shape[2]
x = self.ConvBlock1(x)
x = self.MaxpoolSpa(x)
x = self.ConvBlock2(x)
x = self.ConvBlock3(x)
x = self.MaxpoolSpa(x)
# Process streams
s_x = self.ConvBlock4(x) # Slow stream
f_x = self.ConvBlock5(x) # Fast stream
# First set of blocks and fusion
s_x1 = self.Block1(s_x)
s_x1 = self.MaxpoolSpa(s_x1)
s_x1 = self.drop_1(s_x1)
f_x1 = self.Block4(f_x)
f_x1 = self.MaxpoolSpa(f_x1)
f_x1 = self.drop_2(f_x1)
s_x1 = self.fuse_1(s_x1,f_x1) # LateralConnection
# Second set of blocks and fusion
s_x2 = self.Block2(s_x1)
s_x2 = self.MaxpoolSpa(s_x2)
s_x2 = self.drop_3(s_x2)
f_x2 = self.Block5(f_x1)
f_x2 = self.MaxpoolSpa(f_x2)
f_x2 = self.drop_4(f_x2)
s_x2 = self.fuse_2(s_x2,f_x2) # LateralConnection
# Third blocks and upsampling
s_x3 = self.Block3(s_x2)
s_x3 = self.upsample1(s_x3)
s_x3 = self.drop_5(s_x3)
f_x3 = self.Block6(f_x2)
f_x3 = self.ConvBlock6(f_x3)
f_x3 = self.drop_6(f_x3)
# Final fusion and upsampling
x_fusion = torch.cat((f_x3, s_x3), dim=1)
x_final = self.upsample2(x_fusion)
x_final = self.poolspa(x_final)
x_final = self.ConvBlockLast(x_final)
rPPG = x_final.view(-1, length)
return rPPG
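

# Minimal smoke test (a sketch, not part of the original module): builds the
# model and pushes a random clip through it. The Mamba selective-scan kernel
# needs CUDA, so this assumes a GPU is available; the 128x128 input size is an
# illustrative choice.
if __name__ == "__main__":
    device = torch.device("cuda")
    model = PhysMamba(frames=128).to(device).eval()
    clip = torch.randn(1, 3, 128, 128, 128, device=device)  # (B, C, T, H, W)
    with torch.no_grad():
        rppg = model(clip)
    print(rppg.shape)  # expected: torch.Size([1, 128])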