# myTest01 / analysis / sandbox_fid.py
# (repository page header: "meng2003's picture", "Upload 357 files", commit 2d5fdd1)
import numpy as np
import sklearn
import pickle
from pathlib import Path
import scipy.linalg
import matplotlib.pyplot as plt
#%%
def FID(m, C, mg, Cg):
    """Return the Frechet distance between two Gaussians given their moments.

    Computes ``||m - mg||^2 + Tr(C) + Tr(Cg) - 2 * Tr(sqrtm(C @ Cg))``.

    Args:
        m: Mean vector of the first distribution.
        C: Second-moment (covariance) matrix of the first distribution.
        mg: Mean vector of the second distribution.
        Cg: Second-moment (covariance) matrix of the second distribution.

    Returns:
        The distance as a real scalar.
    """
    import warnings

    m = np.asarray(m)
    mg = np.asarray(mg)
    C = np.asarray(C)
    Cg = np.asarray(Cg)

    mean_diff = np.sum((m - mg) ** 2)

    # sqrtm switches to a complex result as soon as the product has a
    # (possibly tiny, round-off induced) negative eigenvalue.  The distance
    # is a real quantity, so keep only the real part and warn when the
    # discarded imaginary part is too large to be numerical noise.
    cov_sqrt = scipy.linalg.sqrtm(np.dot(C, Cg))
    if np.iscomplexobj(cov_sqrt):
        if not np.allclose(cov_sqrt.imag, 0, atol=1e-3):
            warnings.warn(
                "FID: matrix square root has a non-negligible imaginary "
                "component (max |imag| = {:.3g}); it is being discarded".format(
                    np.max(np.abs(cov_sqrt.imag))
                ),
                RuntimeWarning,
                stacklevel=2,
            )
        cov_sqrt = cov_sqrt.real

    covar_diff = np.trace(C) + np.trace(Cg) - 2 * np.trace(cov_sqrt)
    return mean_diff + covar_diff
#%%
# feat_file = "inference/generated_1/moglow_expmap/predicted_mods/"+"aistpp_gBR_sBM_cAll_d04_mBR3_ch10.expmap_scaled_20.generated.npy"
# feats = np.load(feat_file)
#
# feats = feats[:,0,:]
# feats = np.delete(feats,[-4,-6],1)
#
# feats.shape
#
# C = np.dot(feats.T,feats)
#
# m = np.mean(feats,0)
# data_path="data/dance_combined"
# feature_name="expmap_scaled_20"
# transform_name="scaler"
# transform = pickle.load(open(Path(data_path).joinpath(feature_name+'_'+transform_name+'.pkl'), "rb"))
#
# C_data = transform.
#
# C_data.shape
#%%
# Compare each experiment's generated-motion moments against the ground truth.
root_dir = "data/fid_data/predicted_mods"
# experiment_name="moglow_expmap"
# stat="2moments" # mean and covariance of poses
stat = "2moments_ext"  # mean and covariance of 3 consecutive poses

# Generated statistics have 67 features per pose; the features at offsets -4
# and -6 (counted from the end of each pose) are dropped before comparing with
# the ground truth.
# NOTE(review): assumes the "_ext" statistics are laid out pose after pose
# (3 x 67) -- confirm against the script that writes these pickles.
POSE_DIM = 67
POSES_PER_STAT = {"2moments": 1, "2moments_ext": 3}
# All indices are expressed relative to the ORIGINAL array and removed in a
# single np.delete call.  Deleting pose by pose would shift the later negative
# indices by two for every pose already trimmed and remove the wrong features.
drop_idx = [
    offset - POSE_DIM * k
    for k in range(POSES_PER_STAT.get(stat, 0))
    for offset in (-4, -6)
]

# NOTE: pickle must only be used on trusted, locally produced files.
moments_file = root_dir+"/"+"ground_truth"+"/bvh_expmap_cr_"+stat+".pkl"
with open(moments_file, "rb") as f:
    gt_m, gt_C = pickle.load(f)

moments_dict = {}
fids = {}
experiments = ["moglow_expmap","transflower_expmap","transflower_expmap_finetune2_old","transformer_expmap"]
for experiment_name in experiments:
    moments_file = root_dir+"/"+experiment_name+"/expmap_scaled_20.generated_"+stat+".pkl"
    with open(moments_file, "rb") as f:
        m, C = pickle.load(f)
    if drop_idx:
        m = np.delete(m, drop_idx, 0)
        # Remove the same features from both the rows and the columns of C.
        C = np.delete(np.delete(C, drop_idx, 0), drop_idx, 1)
    moments_dict[experiment_name] = (m, C)
    fids[experiment_name] = FID(m, C, gt_m, gt_C)
fids
#%%
#####
# for comparing seeds
# fids[i, j] is the FID between ground-truth set i+1 and generated seed j+1.
root_dir_generated = "data/fid_data/predicted_mods_seed"
root_dir_gt = "data/fid_data/ground_truths"
fids = np.empty((5,5))
# stat="2moments" # mean and covariance of poses
stat = "2moments_ext"  # mean and covariance of 3 consecutive poses
# seeds = list(range(1,6))

# Generated statistics have 67 features per pose; the features at offsets -4
# and -6 (counted from the end of each pose) are dropped before comparing with
# the ground truth.
# NOTE(review): assumes the "_ext" statistics are laid out pose after pose
# (3 x 67) -- confirm against the script that writes these pickles.
POSE_DIM = 67
POSES_PER_STAT = {"2moments": 1, "2moments_ext": 3}
# All indices are expressed relative to the ORIGINAL array and removed in a
# single np.delete call.  Deleting pose by pose would shift the later negative
# indices by two for every pose already trimmed and remove the wrong features.
drop_idx = [
    offset - POSE_DIM * k
    for k in range(POSES_PER_STAT.get(stat, 0))
    for offset in (-4, -6)
]

# NOTE: pickle must only be used on trusted, locally produced files.
for i in range(5):
    gt_moments_file = root_dir_gt+"/"+str(i+1)+"/bvh_expmap_cr_"+stat+".pkl"
    with open(gt_moments_file, "rb") as f:
        gt_m, gt_C = pickle.load(f)
    for j in range(5):
        # moments_file = root_dir_generated+"/"+"generated_"+str(j+1)+"/expmap_scaled_20.generated_"+stat+".pkl"
        moments_file = "inference/randomized_seeds/generated_"+str(j+1)+"/transflower_expmap/predicted_mods/expmap_scaled_20.generated_"+stat+".pkl"
        with open(moments_file, "rb") as f:
            m, C = pickle.load(f)
        if drop_idx:
            m = np.delete(m, drop_idx, 0)
            # Remove the same features from both the rows and the columns of C.
            C = np.delete(np.delete(C, drop_idx, 0), drop_idx, 1)
        # moments_dict[experiment_name] = (m,C)
        fids[i,j] = FID(m, C, gt_m, gt_C)
# for i in range(5):
#     for j in range(i,5):
#         fids[j,i] = fids[i,j]
fids
# plt.matshow(fids/np.mean(fids))
plt.matshow(fids)
# plt.matshow(fids[1:,1:])
# Boolean mask over the sub-matrix without the first row/column: True where an
# entry equals the minimum of its column.
plt.matshow(fids[1:,1:] == np.min(fids[1:,1:],0,keepdims=True))