"""Inference wrappers around the PhysNet and DeepPhys rPPG models.

Each wrapper loads a checkpoint, pre-processes a list of video frames into the
tensor layout its network consumes, and returns the predicted rPPG waveform.
"""
import collections

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

from contrast_phys.DeepPhysModel import DeepPhys
from contrast_phys.PhysNetModel import PhysNet
from utils_sig import *


class _RppgModelBase:
    """Weight loading, plotting and normalization shared by both wrappers.

    Subclasses are expected to set ``self.device``, ``self.model``,
    ``self.rppg`` and ``self.fps`` in their ``__init__``.
    """

    def _load_weights(self, model_path):
        """Load a checkpoint into ``self.model`` when the path contains "pt".

        Args:
            model_path: Path to a saved ``state_dict``.
        """
        # NOTE(review): this is a substring test, so any path containing "pt"
        # (".pt", ".pth", "/opt/...") takes this branch.
        if "pt" in model_path:
            print("Testing uses pt!")
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # that come from a trusted source.
            weight = torch.load(model_path, map_location=self.device)
            # Remove every "module." occurrence from the key names
            # (presumably the prefix added by nn.DataParallel) so they match
            # the bare model.
            new_dict = collections.OrderedDict(
                (old_key.replace("module.", ""), value)
                for old_key, value in weight.items()
            )
            self.model.load_state_dict(new_dict)
        else:
            # Make the "untrained model" case visible instead of silent.
            print("Warning: no weights loaded from", model_path)

    def plot(self):
        """Build a figure with the rPPG waveform and its power spectrum.

        Reads ``self.rppg`` and ``self.fps``. Note that ``predict`` does not
        store its result on the instance; the caller is expected to assign
        ``self.rppg`` before plotting.

        Returns:
            The Matplotlib figure holding both axes.
        """
        _, psd_y, psd_x = hr_fft(self.rppg, fs=self.fps)

        fig, (ax1, ax2) = plt.subplots(2, figsize=(20, 10))

        ax1.plot(np.arange(len(self.rppg)) / self.fps, self.rppg)
        ax1.set_xlabel('time (sec)')
        ax1.grid(True)
        ax1.set_title('rPPG waveform')

        ax2.plot(psd_x, psd_y)
        ax2.set_xlabel('heart rate (bpm)')
        ax2.set_xlim([40, 200])
        ax2.grid(True)
        ax2.set_title('PSD')

        return fig

    def show(self):
        """Display the pending Matplotlib figures."""
        plt.show()

    def standardized_data(self, data):
        """Z-score standardization for video data.

        The mean and standard deviation are taken over the whole array.
        NaNs (e.g. from a zero standard deviation) are replaced with 0.
        """
        data = data - np.mean(data)
        data = data / np.std(data)
        data[np.isnan(data)] = 0
        return data


class PhysNet_Model(_RppgModelBase):
    """Wrapper that runs PhysNet on 128x128 frames."""

    def __init__(self, model_path):
        """Create the network on GPU if available and load its weights.

        Args:
            model_path: Path to the checkpoint; see ``_load_weights``.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
        print("device", self.device)
        self.model = PhysNet(S=2).to(self.device).eval()
        self.rppg = []
        self.fps = 30  # default: 30 fps
        self._load_weights(model_path)

    def predict(self, frame_list):
        """Predict the rPPG signal from the last 128 frames.

        Args:
            frame_list: Sequence of frames as NumPy arrays (see ``load_data``).

        Returns:
            ``(rppg, face_list)``: samples ``[20:100]`` of the predicted
            waveform as a NumPy array, and the standardized input array.
        """
        print("model processing")
        face_list = self.load_data(frame_list[-128:])
        face_list = self.standardized_data(face_list)
        face_list_t = torch.tensor(face_list.astype('float32')).to(self.device)
        print("+++++face_list_t++++++++", face_list_t.shape)  # need [1, 3, 128, 128, 128]
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            rppg = self.model(face_list_t)[:, -1, :]
        # Keep samples 20..99 only (presumably to drop edge artifacts).
        rppg = rppg[0].detach().cpu().numpy()[20:100]
        print("model done")
        return rppg, face_list

    def predict_statistic(self, frame_list):
        """Predict the rPPG signal from all given frames.

        Args:
            frame_list: Sequence of frames as NumPy arrays (see ``load_data``).

        Returns:
            ``(rppg, length, face_list)``: the predicted waveform with the
            first and last 20 samples removed, its length, and the
            standardized input array.
        """
        print("model processing")
        face_list = self.load_data(frame_list)
        face_list = self.standardized_data(face_list)
        face_list_t = torch.tensor(face_list.astype('float32')).to(self.device)
        print("+++++face_list_t++++++++", face_list_t.shape)  # need [1, 3, 128, 128, 128]
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            rppg = self.model(face_list_t)[:, -1, :]
        rppg = rppg[0].detach().cpu().numpy()
        rppg = rppg[20:len(rppg) - 20]
        print("model done")
        return rppg, len(rppg), face_list

    def load_data(self, frame_list):
        """Resize frames to 128x128 and stack them as (1, C, D, H, W).

        Args:
            frame_list: Sequence of (H, W, C) frames as NumPy arrays.

        Returns:
            A float32 array of shape (1, C, D, 128, 128), where D is the
            number of frames.
        """
        face_list = [
            cv2.resize(frame.astype('float32'), (128, 128), interpolation=cv2.INTER_AREA)
            for frame in frame_list
        ]
        face_list = np.array(face_list)  # (D, H, W, C)
        print("============= face_list shape ==============", face_list.shape)
        face_list = np.transpose(face_list, (3, 0, 1, 2))  # (C, D, H, W)
        return face_list[np.newaxis]  # (1, C, D, H, W)


class DeepPhys_Model(_RppgModelBase):
    """Wrapper that runs DeepPhys on 72x72 frames."""

    def __init__(self, model_path):
        """Create the network on GPU if available and load its weights.

        Args:
            model_path: Path to the checkpoint; see ``_load_weights``.
        """
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
        self.model = DeepPhys(img_size=72).to(self.device).eval()
        self.rppg = []
        self.fps = 30  # default: 30 fps
        self._load_weights(model_path)

    def predict(self, frame_list):
        """Predict the rPPG signal from the last 180 frames.

        Args:
            frame_list: Sequence of frames as NumPy arrays (see ``load_data``).

        Returns:
            ``(rppg, face_list)``: samples ``[20:100]`` of the flattened model
            output as a NumPy array, and the pre-processed (N, 6, 72, 72)
            input array.
        """
        print("model processing")
        face_list = self.load_data(frame_list[-180:])
        face_list_t = torch.tensor(face_list.astype('float32')).to(self.device)
        print("face_list_t shape =============", face_list_t.shape)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            rppg = self.model(face_list_t).flatten()
        print("++++++++++++++rppg++++++++++++++++", rppg)
        # Keep samples 20..99 only (presumably to drop edge artifacts).
        rppg = rppg.detach().cpu().numpy()[20:100]
        print("model done")
        return rppg, face_list

    def load_data(self, frame_list):
        """Resize frames to 72x72 and build the 6-channel network input.

        The standardized frames and the normalized frame differences are
        concatenated along the channel axis.

        Args:
            frame_list: Sequence of (H, W, C) frames as NumPy arrays.

        Returns:
            An array of shape (N, 2 * C, 72, 72).
        """
        face_list = np.array([cv2.resize(frame, (72, 72)) for frame in frame_list])  # (N, H, W, C)

        frame_list_standardized = self.standardized_data(face_list)
        frame_list_diff_normalize = self.diff_normalize_data(face_list)
        # NOTE(review): channels are ordered standardized first, difference
        # second; confirm this matches what DeepPhys.forward expects.
        face_list = np.concatenate(
            (frame_list_standardized, frame_list_diff_normalize), axis=3
        )  # (N, H, W, 2 * C)
        return np.transpose(face_list, (0, 3, 1, 2))  # (N, 2 * C, H, W)

    def diff_normalize_data(self, data):
        """Normalized frame-to-frame difference along the time axis.

        For consecutive frames ``a`` and ``b`` this computes
        ``(b - a) / (b + a + 1e-7)``, divides the result by its standard
        deviation, and appends one all-zero frame so the output has as many
        frames as the input. NaNs are replaced with 0.

        Args:
            data: Array of shape (N, H, W, C).

        Returns:
            A float32 array of shape (N, H, W, C).
        """
        data = np.asarray(data)
        n, h, w, c = data.shape
        if n == 0:
            return np.zeros((0, h, w, c), dtype=np.float32)
        if not np.issubdtype(data.dtype, np.floating):
            # Integer frames (e.g. uint8) would wrap around on subtraction
            # and addition, so do the arithmetic in floating point.
            data = data.astype(np.float32)

        following, preceding = data[1:], data[:-1]
        # All n - 1 consecutive pairs are used.
        diffnormalized_data = (
            (following - preceding) / (following + preceding + 1e-7)
        ).astype(np.float32)
        diffnormalized_data = diffnormalized_data / np.std(diffnormalized_data)

        diffnormalized_data_padding = np.zeros((1, h, w, c), dtype=np.float32)
        diffnormalized_data = np.append(
            diffnormalized_data, diffnormalized_data_padding, axis=0
        )
        diffnormalized_data[np.isnan(diffnormalized_data)] = 0
        return diffnormalized_data