Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import sklearn | |
| import pickle | |
| from pathlib import Path | |
| import scipy.linalg | |
| import matplotlib.pyplot as plt | |
| #%% | |
def FID(m, C, mg, Cg):
    """Return the Frechet distance between two Gaussians given their moments.

    Computes ``sum((m - mg)**2) + tr(C) + tr(Cg) - 2 * tr(sqrtm(C @ Cg))``.

    Args:
        m: Mean vector of the first distribution.
        C: Square covariance matrix of the first distribution.
        mg: Mean vector of the second distribution (same length as ``m``).
        Cg: Square covariance matrix of the second distribution (same shape
            as ``C``).

    Returns:
        The distance as a real scalar.
    """
    mean_diff = np.sum((m - mg) ** 2)
    # sqrtm may return a complex matrix when the product has (numerically)
    # negative eigenvalues, e.g. for near-singular covariances. For genuine
    # covariance matrices the trace is real in exact arithmetic, so only the
    # real part is kept; otherwise the result would be a complex "distance".
    sqrt_trace = np.real(np.trace(scipy.linalg.sqrtm(np.dot(C, Cg))))
    covar_diff = np.trace(C) + np.trace(Cg) - 2 * sqrt_trace
    return mean_diff + covar_diff
| #%% | |
| # feat_file = "inference/generated_1/moglow_expmap/predicted_mods/"+"aistpp_gBR_sBM_cAll_d04_mBR3_ch10.expmap_scaled_20.generated.npy" | |
| # feats = np.load(feat_file) | |
| # | |
| # feats = feats[:,0,:] | |
| # feats = np.delete(feats,[-4,-6],1) | |
| # | |
| # feats.shape | |
| # | |
| # C = np.dot(feats.T,feats) | |
| # | |
| # m = np.mean(feats,0) | |
| # data_path="data/dance_combined" | |
| # feature_name="expmap_scaled_20" | |
| # transform_name="scaler" | |
| # transform = pickle.load(open(Path(data_path).joinpath(feature_name+'_'+transform_name+'.pkl'), "rb")) | |
| # | |
| # C_data = transform. | |
| # | |
| # C_data.shape | |
| #%% | |
# Compute one FID score per experiment, comparing the moments of its generated
# motion against the ground-truth moments.
root_dir = "data/fid_data/predicted_mods"
# experiment_name="moglow_expmap"
# stat="2moments" # mean and covariance of poses
stat = "2moments_ext"  # mean and covariance of 3 consecutive poses

# NOTE: pickle is only safe on trusted files; these are assumed to be locally
# generated statistics.
moments_file = root_dir + "/" + "ground_truth" + "/bvh_expmap_cr_" + stat + ".pkl"
with open(moments_file, "rb") as f:
    gt_m, gt_C = pickle.load(f)

# Two feature dimensions (-4 and -6, counted from the end) are removed from the
# generated statistics; for "2moments_ext" the same removal is repeated at
# offsets of 67 and 2*67 from the end. Presumably this aligns the generated
# features with the ground-truth feature layout -- TODO confirm.
# The deletions are applied sequentially on purpose: each pass indexes into the
# already-shrunk array, exactly as the original step-by-step code did.
if stat == "2moments":
    drop_offsets = (0,)
elif stat == "2moments_ext":
    drop_offsets = (0, 67, 67 * 2)
else:
    drop_offsets = ()

moments_dict = {}
fids = {}
experiments = ["moglow_expmap", "transflower_expmap", "transflower_expmap_finetune2_old", "transformer_expmap"]
for experiment_name in experiments:
    moments_file = root_dir + "/" + experiment_name + "/expmap_scaled_20.generated_" + stat + ".pkl"
    with open(moments_file, "rb") as f:
        m, C = pickle.load(f)
    for offset in drop_offsets:
        drop = [-4 - offset, -6 - offset]
        m = np.delete(m, drop, 0)
        # Rows and columns of the covariance are removed with the same indices.
        C = np.delete(C, drop, 0)
        C = np.delete(C, drop, 1)
    moments_dict[experiment_name] = (m, C)
    fids[experiment_name] = FID(m, C, gt_m, gt_C)
fids
| #%% | |
#####
# for comparing seeds
# Fills a 5x5 matrix where fids[i, j] is the FID between ground-truth set i+1
# and the motion generated with seed j+1.
root_dir_generated = "data/fid_data/predicted_mods_seed"
root_dir_gt = "data/fid_data/ground_truths"
fids = np.empty((5, 5))
# stat="2moments" # mean and covariance of poses
stat = "2moments_ext"  # mean and covariance of 3 consecutive poses
# seeds = list(range(1,6))

# Two feature dimensions (-4 and -6, counted from the end) are removed from the
# generated statistics; for "2moments_ext" the same removal is repeated at
# offsets of 67 and 2*67 from the end. Presumably this aligns the generated
# features with the ground-truth feature layout -- TODO confirm.
# The deletions are applied sequentially on purpose: each pass indexes into the
# already-shrunk array, exactly as the original step-by-step code did.
if stat == "2moments":
    drop_offsets = (0,)
elif stat == "2moments_ext":
    drop_offsets = (0, 67, 67 * 2)
else:
    drop_offsets = ()

# The generated moments do not depend on the ground-truth index, so load and
# trim them once per seed instead of once per (i, j) pair.
# NOTE: pickle is only safe on trusted files; these are assumed to be locally
# generated statistics.
generated_moments = []
for j in range(5):
    # moments_file = root_dir_generated+"/"+"generated_"+str(j+1)+"/expmap_scaled_20.generated_"+stat+".pkl"
    moments_file = "inference/randomized_seeds/generated_" + str(j + 1) + "/transflower_expmap/predicted_mods/expmap_scaled_20.generated_" + stat + ".pkl"
    with open(moments_file, "rb") as f:
        m, C = pickle.load(f)
    for offset in drop_offsets:
        drop = [-4 - offset, -6 - offset]
        m = np.delete(m, drop, 0)
        # Rows and columns of the covariance are removed with the same indices.
        C = np.delete(C, drop, 0)
        C = np.delete(C, drop, 1)
    generated_moments.append((m, C))

for i in range(5):
    gt_moments_file = root_dir_gt + "/" + str(i + 1) + "/bvh_expmap_cr_" + stat + ".pkl"
    with open(gt_moments_file, "rb") as f:
        gt_m, gt_C = pickle.load(f)
    for j, (m, C) in enumerate(generated_moments):
        # moments_dict[experiment_name] = (m,C)
        fids[i, j] = FID(m, C, gt_m, gt_C)
# for i in range(5):
#     for j in range(i,5):
#         fids[j,i] = fids[i,j]
fids
# plt.matshow(fids/np.mean(fids))
plt.matshow(fids)
# plt.matshow(fids[1:,1:])
# Highlight, per column of the sub-matrix, the row(s) attaining the minimum FID.
plt.matshow(fids[1:, 1:] == np.min(fids[1:, 1:], 0, keepdims=True))