"""EfficientPhys: Enabling Simple, Fast and Accurate Camera-Based Vitals Measurement
Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV 2023)
Xin Liu, Brial Hill, Ziheng Jiang, Shwetak Patel, Daniel McDuff
"""
import torch
import torch.nn as nn
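
# Layout convention in this file (inferred from the TSM blocks and forward()):
# video clips arrive flattened along the batch axis as (N * frame_depth, C, H, W).
# forward() takes a first-order temporal difference along dim 0, so callers
# should supply one extra trailing frame to keep the differenced frame count
# divisible by frame_depth.
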
class Attention_mask(nn.Module):
    def __init__(self):
        super(Attention_mask, self).__init__()

    def forward(self, x):
        # Normalize the saliency map by its spatial sum, then rescale by
        # H * W / 2 so the mask keeps a mean value of 0.5 at any resolution.
        xsum = torch.sum(x, dim=2, keepdim=True)
        xsum = torch.sum(xsum, dim=3, keepdim=True)
        xshape = tuple(x.size())
        return x / xsum * xshape[2] * xshape[3] * 0.5

    def get_config(self):
        """Config-export stub carried over from the Keras version. nn.Module
        has no get_config, so return an empty dict rather than calling super()."""
        return {}
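
# Quick property check for the mask (a sketch with made-up values, not part of
# the model): for any positive map g, the rescaled mask averages to 0.5:
#   g = torch.rand(2, 1, 36, 36)
#   m = Attention_mask()(g)
#   m.mean(dim=(2, 3))  # ~= 0.5 for every sample
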
class TSM(nn.Module):
    def __init__(self, n_segment=10, fold_div=3):
        super(TSM, self).__init__()
        self.n_segment = n_segment
        self.fold_div = fold_div

    def forward(self, x):
        nt, c, h, w = x.size()
        n_batch = nt // self.n_segment
        x = x.view(n_batch, self.n_segment, c, h, w)
        fold = c // self.fold_div
        out = torch.zeros_like(x)
        out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
        out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
        out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
        return out.view(nt, c, h, w)
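
# The temporal shift module mixes information across neighboring frames at zero
# FLOP cost (Lin et al., ICCV 2019): the first `fold` channels read one frame
# ahead, the next `fold` channels one frame behind, and the rest stay in place;
# clip boundaries are zero-padded. A tiny behavioral sketch (hypothetical
# values, assuming frame_depth divides the batch):
#   x = torch.arange(4.)[:, None, None, None].repeat(1, 3, 1, 1)  # 4 frames, 3 channels
#   TSM(n_segment=4, fold_div=3)(x)[..., 0, 0]
#   # channel 0 -> [1, 2, 3, 0], channel 1 -> [0, 0, 1, 2], channel 2 unchanged
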
class EfficientPhys(nn.Module):
    def __init__(self, in_channels=3, nb_filters1=32, nb_filters2=64, kernel_size=3, dropout_rate1=0.25,
                 dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128, frame_depth=20, img_size=36, channel='raw'):
        super(EfficientPhys, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout_rate1 = dropout_rate1
        self.dropout_rate2 = dropout_rate2
        self.pool_size = pool_size
        self.nb_filters1 = nb_filters1
        self.nb_filters2 = nb_filters2
        self.nb_dense = nb_dense
        # TSM layers
        self.TSM_1 = TSM(n_segment=frame_depth)
        self.TSM_2 = TSM(n_segment=frame_depth)
        self.TSM_3 = TSM(n_segment=frame_depth)
        self.TSM_4 = TSM(n_segment=frame_depth)
        # Motion branch convs (conv1/conv3 pad to preserve the feature map;
        # conv2/conv4 are unpadded "valid" convs, each shrinking it by 2 pixels)
        self.motion_conv1 = nn.Conv2d(self.in_channels, self.nb_filters1, kernel_size=self.kernel_size,
                                      padding=(1, 1), bias=True)
        self.motion_conv2 = nn.Conv2d(self.nb_filters1, self.nb_filters1, kernel_size=self.kernel_size, bias=True)
        self.motion_conv3 = nn.Conv2d(self.nb_filters1, self.nb_filters2, kernel_size=self.kernel_size,
                                      padding=(1, 1), bias=True)
        self.motion_conv4 = nn.Conv2d(self.nb_filters2, self.nb_filters2, kernel_size=self.kernel_size, bias=True)
        # Attention layers (the "apperance" spelling is kept as-is so state_dict
        # keys continue to match existing checkpoints)
        self.apperance_att_conv1 = nn.Conv2d(self.nb_filters1, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_1 = Attention_mask()
        self.apperance_att_conv2 = nn.Conv2d(self.nb_filters2, 1, kernel_size=1, padding=(0, 0), bias=True)
        self.attn_mask_2 = Attention_mask()
        # Avg pooling (avg_pooling_2 is defined but unused in forward)
        self.avg_pooling_1 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_2 = nn.AvgPool2d(self.pool_size)
        self.avg_pooling_3 = nn.AvgPool2d(self.pool_size)
        # Dropout layers (dropout_2 is defined but unused in forward)
        self.dropout_1 = nn.Dropout(self.dropout_rate1)
        self.dropout_2 = nn.Dropout(self.dropout_rate1)
        self.dropout_3 = nn.Dropout(self.dropout_rate1)
        self.dropout_4 = nn.Dropout(self.dropout_rate2)
        # Dense layers; the input width is the flattened feature size after the
        # two valid convs and two 2x2 poolings for each supported resolution
        if img_size == 36:
            self.final_dense_1 = nn.Linear(3136, self.nb_dense, bias=True)   # 64 * 7 * 7
        elif img_size == 72:
            self.final_dense_1 = nn.Linear(16384, self.nb_dense, bias=True)  # 64 * 16 * 16
        elif img_size == 96:
            self.final_dense_1 = nn.Linear(30976, self.nb_dense, bias=True)  # 64 * 22 * 22
        else:
            raise ValueError('Unsupported image size')
        self.final_dense_2 = nn.Linear(self.nb_dense, 1, bias=True)
        self.batch_norm = nn.BatchNorm2d(3)  # note: hard-coded to 3 input channels
        self.channel = channel
    def forward(self, inputs, params=None):
        # Compute frame differences inside the network (replacing offline
        # normalized-difference preprocessing), then standardize with batch norm.
        inputs = torch.diff(inputs, dim=0)
        inputs = self.batch_norm(inputs)
        # Block 1: two TSM + conv stages, gated by the first attention mask
        network_input = self.TSM_1(inputs)
        d1 = torch.tanh(self.motion_conv1(network_input))
        d1 = self.TSM_2(d1)
        d2 = torch.tanh(self.motion_conv2(d1))
        g1 = torch.sigmoid(self.apperance_att_conv1(d2))
        g1 = self.attn_mask_1(g1)
        gated1 = d2 * g1
        d3 = self.avg_pooling_1(gated1)
        d4 = self.dropout_1(d3)
        # Block 2: same pattern at twice the channel width
        d4 = self.TSM_3(d4)
        d5 = torch.tanh(self.motion_conv3(d4))
        d5 = self.TSM_4(d5)
        d6 = torch.tanh(self.motion_conv4(d5))
        g2 = torch.sigmoid(self.apperance_att_conv2(d6))
        g2 = self.attn_mask_2(g2)
        gated2 = d6 * g2
        d7 = self.avg_pooling_3(gated2)
        d8 = self.dropout_3(d7)
        # Regression head: one pulse value per frame
        d9 = d8.view(d8.size(0), -1)
        d10 = torch.tanh(self.final_dense_1(d9))
        d11 = self.dropout_4(d10)
        out = self.final_dense_2(d11)
        return out
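

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original toolbox). The clip
    # is frame_depth + 1 frames long because forward() applies torch.diff along
    # dim 0 and the TSM blocks need the differenced frame count to be a
    # multiple of frame_depth.
    frame_depth, img_size = 20, 36
    model = EfficientPhys(frame_depth=frame_depth, img_size=img_size).eval()
    dummy_clip = torch.randn(frame_depth + 1, 3, img_size, img_size)
    with torch.no_grad():
        bvp = model(dummy_clip)
    print(bvp.shape)  # expected: torch.Size([20, 1]) -- one pulse value per frame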