Spaces:
Runtime error
Runtime error
import numpy as np | |
import sklearn | |
import pickle | |
from pathlib import Path | |
import scipy.linalg | |
import matplotlib.pyplot as plt | |
#%% | |
def FID(m,C,mg,Cg): | |
mean_diff = np.sum((m-mg)**2) | |
covar_diff = np.trace(C) + np.trace(Cg) -2 * np.trace(scipy.linalg.sqrtm(np.dot(C,Cg))) | |
return mean_diff + covar_diff | |
#%% | |
# feat_file = "inference/generated_1/moglow_expmap/predicted_mods/"+"aistpp_gBR_sBM_cAll_d04_mBR3_ch10.expmap_scaled_20.generated.npy" | |
# feats = np.load(feat_file) | |
# | |
# feats = feats[:,0,:] | |
# feats = np.delete(feats,[-4,-6],1) | |
# | |
# feats.shape | |
# | |
# C = np.dot(feats.T,feats) | |
# | |
# m = np.mean(feats,0) | |
# data_path="data/dance_combined" | |
# feature_name="expmap_scaled_20" | |
# transform_name="scaler" | |
# transform = pickle.load(open(Path(data_path).joinpath(feature_name+'_'+transform_name+'.pkl'), "rb")) | |
# | |
# C_data = transform. | |
# | |
# C_data.shape | |
#%% | |
root_dir = "data/fid_data/predicted_mods" | |
# experiment_name="moglow_expmap" | |
# stat="2moments" # mean and covariance of poses | |
stat="2moments_ext" # mean and covariance of 3 consecutive poses | |
moments_file = root_dir+"/"+"ground_truth"+"/bvh_expmap_cr_"+stat+".pkl" | |
gt_m, gt_C = pickle.load(open(moments_file,"rb")) | |
moments_dict = {} | |
fids = {} | |
experiments = ["moglow_expmap","transflower_expmap","transflower_expmap_finetune2_old","transformer_expmap"] | |
for experiment_name in experiments: | |
moments_file = root_dir+"/"+experiment_name+"/expmap_scaled_20.generated_"+stat+".pkl" | |
m,C = pickle.load(open(moments_file,"rb")) | |
if stat=="2moments": | |
m = np.delete(m,[-4,-6],0) | |
C = np.delete(C,[-4,-6],0) | |
C = np.delete(C,[-4,-6],1) | |
elif stat=="2moments_ext": | |
m = np.delete(m,[-4,-6],0) | |
m = np.delete(m,[-4-67,-6-67],0) | |
m = np.delete(m,[-4-67*2,-6-67*2],0) | |
C = np.delete(C,[-4,-6],0) | |
C = np.delete(C,[-4-67,-6-67],0) | |
C = np.delete(C,[-4-67*2,-6-67*2],0) | |
C = np.delete(C,[-4,-6],1) | |
C = np.delete(C,[-4-67,-6-67],1) | |
C = np.delete(C,[-4-67*2,-6-67*2],1) | |
moments_dict[experiment_name] = (m,C) | |
fids[experiment_name] = FID(m,C,gt_m,gt_C) | |
fids | |
#%% | |
##### | |
# for comparign seeds | |
root_dir_generated = "data/fid_data/predicted_mods_seed" | |
root_dir_gt = "data/fid_data/ground_truths" | |
fids = np.empty((5,5)) | |
# stat="2moments" # mean and covariance of poses | |
stat="2moments_ext" # mean and covariance of 3 consecutive poses | |
# seeds = list(range(1,6)) | |
for i in range(5): | |
gt_moments_file = root_dir_gt+"/"+str(i+1)+"/bvh_expmap_cr_"+stat+".pkl" | |
gt_m,gt_C = pickle.load(open(gt_moments_file,"rb")) | |
for j in range(5): | |
# moments_file = root_dir_generated+"/"+"generated_"+str(j+1)+"/expmap_scaled_20.generated_"+stat+".pkl" | |
moments_file = "inference/randomized_seeds/generated_"+str(j+1)+"/transflower_expmap/predicted_mods/expmap_scaled_20.generated_"+stat+".pkl" | |
m,C = pickle.load(open(moments_file,"rb")) | |
if stat=="2moments": | |
m = np.delete(m,[-4,-6],0) | |
C = np.delete(C,[-4,-6],0) | |
C = np.delete(C,[-4,-6],1) | |
elif stat=="2moments_ext": | |
m = np.delete(m,[-4,-6],0) | |
m = np.delete(m,[-4-67,-6-67],0) | |
m = np.delete(m,[-4-67*2,-6-67*2],0) | |
C = np.delete(C,[-4,-6],0) | |
C = np.delete(C,[-4-67,-6-67],0) | |
C = np.delete(C,[-4-67*2,-6-67*2],0) | |
C = np.delete(C,[-4,-6],1) | |
C = np.delete(C,[-4-67,-6-67],1) | |
C = np.delete(C,[-4-67*2,-6-67*2],1) | |
# moments_dict[experiment_name] = (m,C) | |
fids[i,j] = FID(m,C,gt_m,gt_C) | |
# for i in range(5): | |
# for j in range(i,5): | |
# fids[j,i] = fids[i,j] | |
fids | |
# plt.matshow(fids/np.mean(fids)) | |
plt.matshow(fids) | |
# plt.matshow(fids[1:,1:]) | |
plt.matshow(fids[1:,1:] == np.min(fids[1:,1:],0,keepdims=True)) | |