Juliojuse's picture
init
fa926f8
raw
history blame
No virus
7.59 kB
from contrast_phys.PhysNetModel import PhysNet
from utils_sig import *
import matplotlib.pyplot as plt
class PhysNet_Model:
    """Inference wrapper around a ``PhysNet`` network for rPPG estimation.

    Attributes:
        device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
        model: ``PhysNet(S=2)`` moved to ``device`` and put in eval mode.
        rppg: Signal drawn by :meth:`plot`. It starts as an empty list and is
            not written by any method of this class, so callers have to
            assign it before plotting.
        fps: Sampling rate used by :meth:`plot` (defaults to 30).
    """

    def __init__(self, model_path):
        """Create the network and load weights from ``model_path``.

        Args:
            model_path: Checkpoint path. Weights are loaded only when the
                string contains ``"pt"``; otherwise the network keeps its
                freshly initialised parameters.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
        print("device", self.device)
        self.model = PhysNet(S=2).to(self.device).eval()
        self.rppg = []
        self.fps = 30  # default: 30 fps
        # NOTE(review): this is a substring match, not an extension check, so
        # any path containing "pt" qualifies -- confirm this is intended.
        if "pt" in model_path:
            print("Testing uses pt!")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            weight = torch.load(model_path, map_location=self.device)
            import collections
            # Strip "module." from the parameter names (presumably left by a
            # DataParallel-wrapped training run -- TODO confirm).
            new_dict = collections.OrderedDict(
                (old_key.replace("module.", ""), value)
                for (old_key, value) in weight.items()
            )
            self.model.load_state_dict(new_dict)

    def _infer(self, frame_list):
        """Preprocess ``frame_list``, run the network, return the raw signal.

        Returns:
            Tuple ``(rppg, face_list)``: the un-trimmed 1-D numpy signal of
            the first batch element and the standardized input array.
        """
        face_list = self.standardized_data(self.load_data(frame_list))
        face_list_t = torch.tensor(face_list.astype('float32')).to(self.device)
        print("+++++face_list_t++++++++", face_list_t.shape)  # (1, C, D, 128, 128)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            # Keep the last entry along dim 1 of the network output
            # (presumably the spatially averaged rPPG -- TODO confirm).
            output = self.model(face_list_t)[:, -1, :]
        return output[0].detach().cpu().numpy(), face_list

    def predict(self, frame_list):
        """Estimate the rPPG signal from the last 128 frames.

        Args:
            frame_list: Sliceable sequence of image arrays.

        Returns:
            Tuple ``(rppg, face_list)`` where ``rppg`` holds samples 20..99
            of the network output and ``face_list`` is the standardized
            input array.
        """
        print("model processing")
        rppg, face_list = self._infer(frame_list[-128:])
        rppg = rppg[20:100]
        print("model done")
        return rppg, face_list

    def predict_statistic(self, frame_list):
        """Estimate the rPPG signal from every frame in ``frame_list``.

        Returns:
            Tuple ``(rppg, length, face_list)`` where ``rppg`` is the network
            output with 20 samples removed from each end, ``length`` is
            ``len(rppg)`` and ``face_list`` is the standardized input array.
        """
        print("model processing")
        rppg, face_list = self._infer(frame_list)
        rppg = rppg[20:len(rppg) - 20]
        print("model done")
        return rppg, len(rppg), face_list

    def load_data(self, frame_list):
        """Resize frames to 128x128 and stack them into one batch.

        Args:
            frame_list: Iterable of numpy image arrays.

        Returns:
            Array of shape ``(1, C, D, 128, 128)`` with ``D`` the number of
            frames.
        """
        resized = np.array([
            cv2.resize(frame.astype('float32'), (128, 128), interpolation=cv2.INTER_AREA)
            for frame in frame_list
        ])  # (D, H, W, C)
        print("============= face_list shape ==============", resized.shape)
        channels_first = np.transpose(resized, (3, 0, 1, 2))  # (C, D, H, W)
        return channels_first[np.newaxis]  # (1, C, D, H, W)

    def plot(self):
        """Build a figure with the rPPG waveform and its power spectrum.

        Reads ``self.rppg`` and ``self.fps``.

        Returns:
            The Matplotlib figure; call :meth:`show` to display it.
        """
        _, psd_y, psd_x = hr_fft(self.rppg, fs=self.fps)
        fig, (ax1, ax2) = plt.subplots(2, figsize=(20, 10))

        ax1.plot(np.arange(len(self.rppg)) / self.fps, self.rppg)
        ax1.set_xlabel('time (sec)')
        ax1.grid(True)
        ax1.set_title('rPPG waveform')

        ax2.plot(psd_x, psd_y)
        ax2.set_xlabel('heart rate (bpm)')
        ax2.set_xlim([40, 200])
        ax2.grid(True)
        ax2.set_title('PSD')
        return fig

    def show(self):
        """Display the open Matplotlib figures."""
        plt.show()

    def standardized_data(self, data):
        """Z-score standardization for video data.

        The mean and standard deviation are taken over the whole array.
        NaNs (e.g. from a zero standard deviation) are replaced with 0.
        """
        centered = data - np.mean(data)
        scaled = centered / np.std(centered)
        scaled[np.isnan(scaled)] = 0
        return scaled
from contrast_phys.DeepPhysModel import DeepPhys
class DeepPhys_Model:
    """Inference wrapper around a ``DeepPhys`` network for rPPG estimation.

    Attributes:
        device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
        model: ``DeepPhys(img_size=72)`` moved to ``device`` in eval mode.
        rppg: Signal drawn by :meth:`plot`. It starts as an empty list and is
            not written by any method of this class, so callers have to
            assign it before plotting.
        fps: Sampling rate used by :meth:`plot` (defaults to 30).
    """

    def __init__(self, model_path):
        """Create the network and load weights from ``model_path``.

        Args:
            model_path: Checkpoint path. Weights are loaded only when the
                string contains ``"pt"``; otherwise the network keeps its
                freshly initialised parameters.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
        self.model = DeepPhys(img_size=72).to(self.device).eval()
        self.rppg = []
        self.fps = 30  # default: 30 fps
        # NOTE(review): this is a substring match, not an extension check, so
        # any path containing "pt" qualifies -- confirm this is intended.
        if "pt" in model_path:
            print("Testing uses pt!")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            weight = torch.load(model_path, map_location=self.device)
            import collections
            # Strip "module." from the parameter names (presumably left by a
            # DataParallel-wrapped training run -- TODO confirm).
            new_dict = collections.OrderedDict(
                (old_key.replace("module.", ""), value)
                for (old_key, value) in weight.items()
            )
            self.model.load_state_dict(new_dict)

    def predict(self, frame_list):
        """Estimate the rPPG signal from the last 180 frames.

        Args:
            frame_list: Sliceable sequence of image arrays.

        Returns:
            Tuple ``(rppg, face_list)`` where ``rppg`` holds samples 20..99
            of the flattened network output and ``face_list`` is the
            preprocessed array produced by :meth:`load_data`.
        """
        print("model processing")
        face_list = self.load_data(frame_list[-180:])
        face_list_t = torch.tensor(face_list.astype('float32')).to(self.device)
        # (N, 2*C, 72, 72): standardized and diff-normalized channels stacked.
        print("face_list_t shape =============", face_list_t.shape)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            rppg = self.model(face_list_t).flatten()
        print("++++++++++++++rppg++++++++++++++++", rppg)
        rppg = rppg.detach().cpu().numpy()[20:100]
        print("model done")
        return rppg, face_list

    def load_data(self, frame_list):
        """Resize frames to 72x72 and build the two-branch network input.

        Each frame is resized, then the stack is both z-score standardized
        and diff-normalized; the two results are concatenated along the
        channel axis.

        Args:
            frame_list: Iterable of numpy image arrays.

        Returns:
            Array of shape ``(N, 2*C, 72, 72)`` with ``N`` the number of
            frames and ``C`` the channel count of a frame.
        """
        resized = np.array([cv2.resize(frame, (72, 72)) for frame in frame_list])  # (N, H, W, C)
        standardized = self.standardized_data(resized)
        diff_normalized = self.diff_normalize_data(resized)
        stacked = np.concatenate((standardized, diff_normalized), axis=3)  # (N, H, W, 2*C)
        return np.transpose(stacked, (0, 3, 1, 2))  # (N, 2*C, H, W)

    def plot(self):
        """Build a figure with the rPPG waveform and its power spectrum.

        Reads ``self.rppg`` and ``self.fps``.

        Returns:
            The Matplotlib figure; call :meth:`show` to display it.
        """
        _, psd_y, psd_x = hr_fft(self.rppg, fs=self.fps)
        fig, (ax1, ax2) = plt.subplots(2, figsize=(20, 10))

        ax1.plot(np.arange(len(self.rppg)) / self.fps, self.rppg)
        ax1.set_xlabel('time (sec)')
        ax1.grid(True)
        ax1.set_title('rPPG waveform')

        ax2.plot(psd_x, psd_y)
        ax2.set_xlabel('heart rate (bpm)')
        ax2.set_xlim([40, 200])
        ax2.grid(True)
        ax2.set_title('PSD')
        return fig

    def show(self):
        """Display the open Matplotlib figures."""
        plt.show()

    def standardized_data(self, data):
        """Z-score standardization for video data.

        The mean and standard deviation are taken over the whole array.
        NaNs (e.g. from a zero standard deviation) are replaced with 0.
        """
        centered = data - np.mean(data)
        scaled = centered / np.std(centered)
        scaled[np.isnan(scaled)] = 0
        return scaled

    def diff_normalize_data(self, data):
        """Normalized frame-to-frame difference along the time axis.

        For consecutive frames ``a`` and ``b`` the value is
        ``(b - a) / (b + a + 1e-7)``. All ``n - 1`` differences are then
        divided by their overall standard deviation and one all-zero frame
        is appended so the output has as many frames as the input.

        Args:
            data: Array of shape ``(n, h, w, c)``. Integer input is converted
                to float32 first so that unsigned subtraction/addition
                cannot wrap around.

        Returns:
            float32 array of shape ``(n, h, w, c)`` with NaNs replaced by 0.
        """
        n, h, w, c = data.shape
        if n == 0:
            return np.zeros((0, h, w, c), dtype=np.float32)
        if not np.issubdtype(data.dtype, np.floating):
            data = data.astype(np.float32)

        following, previous = data[1:], data[:-1]
        diffs = ((following - previous) / (following + previous + 1e-7)).astype(np.float32)
        if diffs.size:
            diffs = diffs / np.std(diffs)

        # Zero frame that brings the length back from n - 1 to n.
        padding = np.zeros((1, h, w, c), dtype=np.float32)
        result = np.concatenate((diffs, padding), axis=0)
        result[np.isnan(result)] = 0
        return result