File size: 2,162 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

# some parts of the code adapted from https://github.com/benjiebob/WLDO and https://github.com/benjiebob/SMALify

import numpy as np
import torch
import pickle as pkl



class ShapePrior(torch.nn.Module):
    """Quadratic penalty on shape coefficients (betas) under a Gaussian prior.

    The prior is read from a pickle holding a mean vector under the key
    ``'dog_cluster_mean'`` and a covariance matrix under ``'dog_cluster_cov'``.
    Let D be the prior's dimensionality. The module stores three buffers, so
    they move with ``.to(device)`` and are saved in ``state_dict``:

    * ``single_gaussian_means``: the mean, shape (D,), float32.
    * ``single_gaussian_precs``: the lower Cholesky factor L of
      ``inv(cov + 1e-5 * I)``, shape (D, D), float32.
    * ``use_ind_tch``: a float32 mask of ones, shape (D,). It is kept so that
      existing checkpoints still load.
    """

    def __init__(self, prior_path):
        """Load the Gaussian shape prior from ``prior_path``.

        Args:
            prior_path: Path to a pickle file containing a dict with keys
                ``'dog_cluster_mean'`` (length-D mean) and
                ``'dog_cluster_cov'`` (D x D covariance).

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If either expected key is missing.
            numpy.linalg.LinAlgError: If the regularized covariance cannot be
                inverted, or its inverse is not positive definite.
        """
        super().__init__()
        # encoding='latin1' lets Python 3 read NumPy arrays pickled under
        # Python 2. It has no effect on pickles written by Python 3.
        # SECURITY: unpickling can execute arbitrary code, so only load
        # trusted prior files.
        with open(prior_path, 'rb') as f:
            res = pkl.load(f, encoding='latin1')

        betas_mean = res['dog_cluster_mean']
        betas_cov = np.asarray(res['dog_cluster_cov'])

        # A small diagonal jitter keeps the inversion well conditioned.
        regularized_cov = betas_cov + 1e-5 * np.eye(betas_cov.shape[0])
        inv_cov = np.linalg.inv(regularized_cov)
        # L satisfies L @ L.T == inv_cov, so ||(x - mu) @ L||^2 is the squared
        # Mahalanobis distance under the regularized covariance.
        chol_prec = torch.tensor(np.linalg.cholesky(inv_cov)).float()
        means = torch.tensor(betas_mean).float()

        self.register_buffer('single_gaussian_precs', chol_prec)   # (D, D)
        self.register_buffer('single_gaussian_means', means)       # (D,)
        self.register_buffer(
            'use_ind_tch', torch.ones(means.shape[0], dtype=torch.float32))

    def forward(self, betas_smal_orig, use_singe_gaussian=False):
        """Return the mean squared whitened residual of a batch of betas.

        Only the first n prior dimensions are used. The input is zero-padded
        to D columns, and the rows and columns of L beyond n are zeroed, so
        the padded dimensions contribute nothing to the residual.

        Args:
            betas_smal_orig: 2-D tensor of shape (B, n) with n <= D.
            use_singe_gaussian: Unused. Kept so existing callers that pass it
                keep working.

        Returns:
            A 0-dim tensor holding ``mean(((x_pad - mu) @ L_masked) ** 2)``.
            The mean is taken over all B * D entries, so the divisor is
            B * D even when n < D.

        Raises:
            ValueError: If the input has more columns than the prior
                dimensionality D.
        """
        n_betas = betas_smal_orig.shape[1]
        n_prior = self.single_gaussian_means.shape[0]
        if n_betas > n_prior:
            raise ValueError(
                f"got {n_betas} shape coefficients, but the prior only has {n_prior}")

        # Mask that keeps the first n_betas prior dimensions and drops the rest.
        keep = torch.zeros_like(self.use_ind_tch)
        keep[:n_betas] = 1.0
        mask = self.use_ind_tch * keep

        # Zero-pad the input up to the prior's dimensionality. The padding is
        # float32, matching the buffers.
        padding = torch.zeros(
            (betas_smal_orig.shape[0], n_prior - n_betas),
            dtype=torch.float32, device=betas_smal_orig.device)
        samples = torch.cat((betas_smal_orig, padding), dim=1)

        mean_sub = samples - self.single_gaussian_means.unsqueeze(0)
        precs_masked = self.single_gaussian_precs * mask[:, None] * mask[None, :]
        whitened = torch.tensordot(mean_sub, precs_masked, dims=([1], [0]))   # (B, D)
        return torch.mean(whitened ** 2)