# Compute Fréchet distances (FID) between the moment statistics (mean and
# covariance of pose features) of generated dance motion and ground truth,
# and compare FIDs across random generation seeds.
import numpy as np
import sklearn
import pickle
from pathlib import Path
import scipy.linalg
import matplotlib.pyplot as plt
#%%

def FID(m,C,mg,Cg):
    # Frechet distance between N(m, C) and N(mg, Cg):
    # ||m - mg||^2 + Tr(C) + Tr(Cg) - 2*Tr((C Cg)^(1/2))
    mean_diff = np.sum((m-mg)**2)
    # sqrtm can return a complex matrix due to numerical error; keep the real part
    covmean = np.real(scipy.linalg.sqrtm(np.dot(C,Cg)))
    covar_diff = np.trace(C) + np.trace(Cg) - 2 * np.trace(covmean)
    return mean_diff + covar_diff
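
#%%
# Quick sanity check of FID (illustrative sketch, not part of the original
# pipeline; _rng, _d, _A, _C, _m are names introduced only for this check):
# identical moments give ~0, and shifting the mean by 1 in every one of the
# _d dimensions adds _d to the distance (the covariance terms cancel).
_rng = np.random.default_rng(0)
_d = 5
_A = _rng.normal(size=(_d, _d))
_C = _A @ _A.T + _d * np.eye(_d)   # well-conditioned SPD covariance
_m = _rng.normal(size=_d)
assert np.isclose(FID(_m, _C, _m, _C), 0.0, atol=1e-6)
assert np.isclose(FID(_m + 1.0, _C, _m, _C), float(_d), atol=1e-6)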
#%%

# feat_file = "inference/generated_1/moglow_expmap/predicted_mods/"+"aistpp_gBR_sBM_cAll_d04_mBR3_ch10.expmap_scaled_20.generated.npy"
# feats = np.load(feat_file)
#
# feats = feats[:,0,:]
# feats = np.delete(feats,[-4,-6],1)
#
# feats.shape
#
# C = np.dot(feats.T,feats)
#
# m = np.mean(feats,0)

# data_path="data/dance_combined"
# feature_name="expmap_scaled_20"
# transform_name="scaler"
# transform = pickle.load(open(Path(data_path).joinpath(feature_name+'_'+transform_name+'.pkl'), "rb"))
#
# C_data = transform.
#
# C_data.shape

#%%
root_dir = "data/fid_data/predicted_mods"
# experiment_name="moglow_expmap"

# stat="2moments" # mean and covariance of poses
stat="2moments_ext" # mean and covariance of 3 consecutive poses
moments_file = root_dir+"/"+"ground_truth"+"/bvh_expmap_cr_"+stat+".pkl"
gt_m, gt_C = pickle.load(open(moments_file,"rb"))

# FID of each experiment's generated-motion moments against the ground truth
moments_dict = {}
fids = {}
experiments = ["moglow_expmap","transflower_expmap","transflower_expmap_finetune2_old","transformer_expmap"]
for experiment_name in experiments:
    moments_file = root_dir+"/"+experiment_name+"/expmap_scaled_20.generated_"+stat+".pkl"

    m,C = pickle.load(open(moments_file,"rb"))
    # drop two feature dims per pose (indices -4 and -6) from the generated
    # moments so they match the ground-truth dimensionality
    if stat=="2moments":
        m = np.delete(m,[-4,-6],0)
        C = np.delete(C,[-4,-6],0)
        C = np.delete(C,[-4,-6],1)
    elif stat=="2moments_ext":
        # same two dims in each of the 3 concatenated pose blocks; the offsets
        # step by 67 because each preceding np.delete shrinks the array by 2
        m = np.delete(m,[-4,-6],0)
        m = np.delete(m,[-4-67,-6-67],0)
        m = np.delete(m,[-4-67*2,-6-67*2],0)
        C = np.delete(C,[-4,-6],0)
        C = np.delete(C,[-4-67,-6-67],0)
        C = np.delete(C,[-4-67*2,-6-67*2],0)
        C = np.delete(C,[-4,-6],1)
        C = np.delete(C,[-4-67,-6-67],1)
        C = np.delete(C,[-4-67*2,-6-67*2],1)
    moments_dict[experiment_name] = (m,C)
    fids[experiment_name] = FID(m,C,gt_m,gt_C)


fids
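
#%%
# Convenience view (not part of the original script): print the experiments
# ordered from lowest (best) to highest FID.
for _name, _fid in sorted(fids.items(), key=lambda kv: kv[1]):
    print(_name, _fid)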
#%%

#####
# for comparing seeds

root_dir_generated = "data/fid_data/predicted_mods_seed"
root_dir_gt = "data/fid_data/ground_truths"
fids = np.empty((5,5))  # rows: ground-truth splits 1-5, columns: generation seeds 1-5
# stat="2moments" # mean and covariance of poses
stat="2moments_ext" # mean and covariance of 3 consecutive poses
# seeds = list(range(1,6))
for i in range(5):
    gt_moments_file = root_dir_gt+"/"+str(i+1)+"/bvh_expmap_cr_"+stat+".pkl"
    gt_m,gt_C = pickle.load(open(gt_moments_file,"rb"))
    for j in range(5):
        # moments_file = root_dir_generated+"/"+"generated_"+str(j+1)+"/expmap_scaled_20.generated_"+stat+".pkl"
        moments_file = "inference/randomized_seeds/generated_"+str(j+1)+"/transflower_expmap/predicted_mods/expmap_scaled_20.generated_"+stat+".pkl"

        m,C = pickle.load(open(moments_file,"rb"))
        if stat=="2moments":  # same dimension-dropping as in the loop above
            m = np.delete(m,[-4,-6],0)
            C = np.delete(C,[-4,-6],0)
            C = np.delete(C,[-4,-6],1)
        elif stat=="2moments_ext":
            m = np.delete(m,[-4,-6],0)
            m = np.delete(m,[-4-67,-6-67],0)
            m = np.delete(m,[-4-67*2,-6-67*2],0)
            C = np.delete(C,[-4,-6],0)
            C = np.delete(C,[-4-67,-6-67],0)
            C = np.delete(C,[-4-67*2,-6-67*2],0)
            C = np.delete(C,[-4,-6],1)
            C = np.delete(C,[-4-67,-6-67],1)
            C = np.delete(C,[-4-67*2,-6-67*2],1)
        # moments_dict[experiment_name] = (m,C)
        fids[i,j] = FID(m,C,gt_m,gt_C)

# for i in range(5):
#     for j in range(i,5):
#         fids[j,i] = fids[i,j]


print(fids)

# plt.matshow(fids/np.mean(fids))
plt.matshow(fids)
# plt.matshow(fids[1:,1:])
# for each seed (column), mark which ground-truth split attains the lowest FID
plt.matshow(fids[1:,1:] == np.min(fids[1:,1:],0,keepdims=True))
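
#%%
# Additional summary (not in the original script): spread of the seed-comparison
# FIDs per ground-truth split, as a quick check of variability across seeds.
print("mean FID over seeds, per GT split:", fids.mean(axis=1))
print("std of FID over seeds, per GT split:", fids.std(axis=1))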