# Scraped page header (not code), kept as a comment so the module can be imported:
# Juliojuse · commit fa926f8 "init" · 4.72 kB
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------------------------------------------------------------------------------------------
# PhysNet model
#
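# The network produces an S x S spatio-temporal rPPG (ST-rPPG) block rather than a single rPPG signal;
# forward() returns the S*S per-cell signals of that block plus their spatial average.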
# -------------------------------------------------------------------------------------------------------------------
class PhysNet(nn.Module):
    """PhysNet-style 3D-CNN that estimates an S x S grid of rPPG signals from a clip.

    Args:
        S: side length of the spatial grid produced by ``self.end`` (the
            ST-rPPG block is (B, 1, T, S, S)).
        in_ch: number of input channels.

    ``forward`` takes ``x`` of shape (B, in_ch, T, H, W) and returns a tensor of
    shape (B, S*S + 1, T): the S*S per-cell signals flattened row-major (index
    ``a*S + b`` is grid cell ``(a, b)``), followed by their spatial average.

    Size constraints implied by the layers: H and W are halved four times, so
    they must be at least 16 (the original shape comments assume 128x128). T is
    halved twice by the encoders, so it must be at least 4. The decoder restores
    the exact input length T, including odd lengths.

    Submodule order inside every ``nn.Sequential`` is kept identical to the
    original layout, so ``state_dict`` keys (e.g. ``loop1.4.weight``) stay
    compatible with existing checkpoints.
    """

    def __init__(self, S=2, in_ch=3):
        super().__init__()
        self.S = S  # S is the spatial dimension of ST-rPPG block
        self.start = nn.Sequential(
            *self._conv_bn_elu(in_ch, 32, kernel_size=(1, 5, 5), padding=(0, 2, 2)),
        )
        # 1x: spatial /2, temporal length unchanged
        self.loop1 = self._pool_double_conv(32, pool=(1, 2, 2))
        # encoder: spatial /2 and temporal /2 (floor) at each stage
        self.encoder1 = self._pool_double_conv(64, pool=(2, 2, 2))
        self.encoder2 = self._pool_double_conv(64, pool=(2, 2, 2))
        # spatial /2, temporal length unchanged
        self.loop4 = self._pool_double_conv(64, pool=(1, 2, 2))
        # decoder to reach back initial temporal length (temporal-only convolutions)
        self.decoder1 = nn.Sequential(
            *self._conv_bn_elu(64, 64, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )
        self.decoder2 = nn.Sequential(
            *self._conv_bn_elu(64, 64, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
        )
        self.end = nn.Sequential(
            nn.AdaptiveAvgPool3d((None, S, S)),
            nn.Conv3d(in_channels=64, out_channels=1, kernel_size=(1, 1, 1), stride=1, padding=(0, 0, 0)),
        )

    @staticmethod
    def _conv_bn_elu(in_ch, out_ch, kernel_size, padding):
        """Return [Conv3d, BatchNorm3d, ELU] as a list for splicing into nn.Sequential."""
        return [
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm3d(out_ch),
            nn.ELU(),
        ]

    @classmethod
    def _pool_double_conv(cls, in_ch, pool):
        """Return AvgPool3d(pool) followed by two 3x3x3 conv/BN/ELU stages producing 64 channels."""
        return nn.Sequential(
            nn.AvgPool3d(kernel_size=pool, stride=pool, padding=0),
            *cls._conv_bn_elu(in_ch, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            *cls._conv_bn_elu(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
        )

    def forward(self, x):
        """Map a clip (B, in_ch, T, H, W) to (B, S*S + 1, T) rPPG signals.

        The first S*S output channels are the ST-rPPG block's cells flattened
        row-major; the last channel is their mean.
        """
        # Standardise each sample/channel over (T, H, W). The std is floored at
        # machine epsilon so a constant clip gives zeros instead of 0/0 = NaN.
        means = torch.mean(x, dim=(2, 3, 4), keepdim=True)
        stds = torch.std(x, dim=(2, 3, 4), keepdim=True)
        x = (x - means) / stds.clamp_min(torch.finfo(stds.dtype).eps)

        x = self.start(x)  # (B, 32, T, H, W)
        x = self.loop1(x)  # (B, 64, T, H/2, W/2)
        # Temporal pooling floors odd lengths; remember the parity before each
        # pooling so the decoder can restore the exact length.
        pad_full = x.size(2) % 2
        x = self.encoder1(x)  # (B, 64, T//2, H/4, W/4)
        pad_half = x.size(2) % 2
        x = self.encoder2(x)  # (B, 64, T//4, H/8, W/8)
        x = self.loop4(x)  # (B, 64, T//4, H/16, W/16)

        x = F.interpolate(x, scale_factor=(2, 1, 1))  # (B, 64, 2*(T//4), ...)
        x = self.decoder1(x)
        # Replicate the last frame once if T//2 was odd -> (B, 64, T//2, ...)
        x = F.pad(x, (0, 0, 0, 0, 0, pad_half), mode='replicate')
        x = F.interpolate(x, scale_factor=(2, 1, 1))  # (B, 64, 2*(T//2), ...)
        x = self.decoder2(x)
        # Replicate the last frame once if T was odd -> (B, 64, T, ...)
        x = F.pad(x, (0, 0, 0, 0, 0, pad_full), mode='replicate')

        x = self.end(x)  # (B, 1, T, S, S), ST-rPPG block
        # Row-major flatten of the S x S grid: channel a*S + b <- cell (a, b).
        spatial = x.flatten(start_dim=3)[:, 0].transpose(1, 2)  # (B, S*S, T)
        average = spatial.mean(dim=1, keepdim=True)  # (B, 1, T)
        return torch.cat([spatial, average], dim=1)  # (B, S*S + 1, T)