Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# -------------------------------------------------------------------------------------------------------------------
# PhysNet model
#
# The output is an ST-rPPG block rather than an rPPG signal.
# -------------------------------------------------------------------------------------------------------------------
class PhysNet(nn.Module):
    """3D-CNN that turns a video clip into an ST-rPPG block.

    The network normalises each sample, encodes it with 3D convolutions while
    halving the spatial size four times and the temporal length twice, then
    upsamples along time back to the input length and pools the spatial plane
    to ``S x S``.

    Args:
        S: Spatial side length of the ST-rPPG block (the block has ``S * S``
            spatial positions).
        in_ch: Number of channels of the input clip.

    Input:
        Tensor of shape ``(B, in_ch, T, H, W)``. The shape comments in
        ``forward`` assume ``H = W = 128``.

    Output:
        Tensor of shape ``(B, S * S + 1, T)``: the ``S * S`` spatial signals
        in row-major order, followed by their mean as the last channel.
    """

    def __init__(self, S: int = 2, in_ch: int = 3) -> None:
        super().__init__()
        self.S = S  # S is the spatial dimension of the ST-rPPG block

        # NOTE: sub-module names and layer order are kept as-is so that
        # existing checkpoints (state_dict keys) remain loadable.
        self.start = nn.Sequential(
            nn.Conv3d(in_channels=in_ch, out_channels=32, kernel_size=(1, 5, 5), stride=1, padding=(0, 2, 2)),
            nn.BatchNorm3d(32),
            nn.ELU(),
        )

        # 1x: spatial downsampling only, temporal length untouched
        self.loop1 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )

        # encoder: each stage halves both the temporal and spatial size
        self.encoder1 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )
        self.encoder2 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )

        # one more spatial-only downsampling stage
        self.loop4 = nn.Sequential(
            nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=(1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )

        # decoder to reach back initial temporal length
        self.decoder1 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )
        self.decoder2 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=64, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)),
            nn.BatchNorm3d(64),
            nn.ELU(),
        )

        # pool the spatial plane to S x S, then project 64 channels to 1
        self.end = nn.Sequential(
            nn.AdaptiveAvgPool3d((None, S, S)),
            nn.Conv3d(in_channels=64, out_channels=1, kernel_size=(1, 1, 1), stride=1, padding=(0, 0, 0)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the flattened ST-rPPG block for a batch of clips.

        Args:
            x: Floating-point tensor of shape ``(B, in_ch, T, H, W)``.

        Returns:
            Tensor of shape ``(B, S * S + 1, T)``; channel ``a * S + b`` is the
            signal at spatial position ``(a, b)`` and the last channel is the
            mean over all positions.
        """
        # Per-sample, per-channel standardisation over (T, H, W).
        means = torch.mean(x, dim=(2, 3, 4), keepdim=True)
        stds = torch.std(x, dim=(2, 3, 4), keepdim=True)
        # A constant clip has std == 0; without a floor the division below is
        # 0/0 and every downstream activation becomes NaN.
        stds = stds.clamp_min(torch.finfo(stds.dtype).eps)
        x = (x - means) / stds  # (B, C, T, 128, 128)

        x = self.start(x)  # (B, C, T, 128, 128)
        x = self.loop1(x)  # (B, 64, T, 64, 64)
        # Temporal pooling floors odd lengths; remember the dropped frame at
        # each level so the decoder can restore the exact input length.
        odd_at_full = x.size(2) % 2
        x = self.encoder1(x)  # (B, 64, T/2, 32, 32)
        odd_at_half = x.size(2) % 2
        x = self.encoder2(x)  # (B, 64, T/4, 16, 16)
        x = self.loop4(x)  # (B, 64, T/4, 8, 8)

        x = F.interpolate(x, scale_factor=(2, 1, 1))  # (B, 64, T/2, 8, 8)
        x = self.decoder1(x)  # (B, 64, T/2, 8, 8)
        # F.pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back):
        # append one replicated frame at the end when a frame was dropped.
        x = F.pad(x, (0, 0, 0, 0, 0, odd_at_half), mode='replicate')
        x = F.interpolate(x, scale_factor=(2, 1, 1))  # (B, 64, T, 8, 8)
        x = self.decoder2(x)  # (B, 64, T, 8, 8)
        x = F.pad(x, (0, 0, 0, 0, 0, odd_at_full), mode='replicate')

        x = self.end(x)  # (B, 1, T, S, S), ST-rPPG block

        # Flatten the S x S plane (row-major) into the channel dimension.
        signals = x.flatten(3).squeeze(1).transpose(1, 2)  # (B, S*S, T)
        average = signals.mean(dim=1, keepdim=True)  # (B, 1, T)
        # (B, S*S + 1, T): all spatial signals followed by their mean
        return torch.cat([signals, average], dim=1)