File size: 3,214 Bytes
760c94e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import h5py
import torch
import torch.utils.data as data
import torch.multiprocessing
import scipy.io as sio
from torch.nn import functional as F
# torch.multiprocessing.set_start_method('spawn')

class H5Dataset(data.Dataset):
    """Map-style dataset yielding paired samples from an HDF5 file.

    The file at ``H5Path`` is opened read-only and must contain the datasets
    ``LeftData`` and ``RightData``. Item ``index`` is the pair
    ``(LeftData[index], RightData[index])`` converted to float32 tensors.

    h5py handles cannot be pickled, so they are dropped when the dataset is
    pickled and the file is reopened when it is unpickled. This lets the
    dataset be handed to DataLoader worker processes started with the
    ``spawn`` start method.
    """

    def __init__(self, H5Path):
        super(H5Dataset, self).__init__()
        self.H5Path = H5Path
        self._open()
        # Masks are no longer read from this file (update 2024-01-11: masks
        # are loaded separately), so only the image datasets are bound.

    def _open(self):
        """Open ``self.H5Path`` read-only and bind ``LeftData``/``RightData``."""
        self.H5File = h5py.File(self.H5Path, 'r')
        self.LeftData = self.H5File['LeftData']
        self.RightData = self.H5File['RightData']

    def __getstate__(self):
        # Drop the unpicklable h5py handles; __setstate__ reopens the file
        # (e.g. inside a DataLoader worker process).
        state = self.__dict__.copy()
        for key in ('H5File', 'LeftData', 'RightData'):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._open()

    def __getitem__(self, index):
        """Return ``(left, right)`` float32 tensors for sample ``index``."""
        # The 4-index slice implies (at least) 4-D datasets; presumably
        # (N, C, H, W) -- TODO confirm against the data-preparation code.
        left = torch.from_numpy(self.LeftData[index, :, :, :]).float()
        right = torch.from_numpy(self.RightData[index, :, :, :]).float()
        return left, right

    def __len__(self):
        """Number of samples, taken from the first axis of ``LeftData``.

        ``RightData`` is assumed (not checked) to have the same length.
        """
        return self.LeftData.shape[0]

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.H5File.close()

def save_image_mat(img_r, img_l, result_path, idx):
    """Write a reconstructed left/right pair to ``<result_path>img<idx>.mat``.

    Both tensors are detached, moved to the CPU and stored as NumPy arrays
    under the MATLAB variable names ``recon_L`` and ``recon_R``. The output
    name is built by plain string concatenation, so ``result_path`` should
    end with a path separator if it names a directory.
    """
    def _as_array(tensor):
        return tensor.detach().cpu().numpy()

    payload = {
        'recon_L': _as_array(img_l),
        'recon_R': _as_array(img_r),
    }
    out_file = result_path + 'img{}.mat'.format(idx)
    sio.savemat(out_file, payload)

def load_dataset(data_path, batch_size):
    """Build train and validation DataLoaders for ``data_path``.

    Reads ``<data_path>_train.h5`` and ``<data_path>_val.h5`` through
    ``H5Dataset``. Both loaders use ``shuffle=False``. When CUDA is
    available they also use one worker process and pinned memory; otherwise
    DataLoader defaults apply.

    Returns:
        ``(train_loader, val_loader)``.
    """
    use_cuda = torch.cuda.is_available()
    loader_opts = dict(num_workers=1, pin_memory=True) if use_cuda else dict()

    # Open both files first (train, then val), then wrap each in a loader.
    datasets = [H5Dataset(data_path + suffix) for suffix in ('_train.h5', '_val.h5')]
    train_loader, val_loader = [
        torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, **loader_opts)
        for ds in datasets
    ]
    return train_loader, val_loader

def load_dataset_test(data_path, batch_size):
    """Build a DataLoader over ``<data_path>.h5`` for evaluation.

    Samples come from ``H5Dataset`` in file order (``shuffle=False``). When
    CUDA is available the loader uses one worker process and pinned memory;
    otherwise DataLoader defaults apply.
    """
    loader_opts = {}
    if torch.cuda.is_available():
        loader_opts.update(num_workers=1, pin_memory=True)
    return torch.utils.data.DataLoader(
        H5Dataset(data_path + '.h5'),
        batch_size=batch_size,
        shuffle=False,
        **loader_opts
    )

# VAE loss function (updated 2024-01-09: zero-valued pixels are masked out of the reconstruction loss)
def loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar, beta, left_mask, right_mask):
    """Masked VAE loss for a left/right image pair.

    Returns the scalar tensor ``KLD * beta / Image_Size**2 + MSE_L + MSE_R``,
    where ``Image_Size = xL.size(3)``.

    Args:
        xL, xR: target tensors. ``.size(3)`` is read, so they are at least
            4-D (presumably (batch, channel, H, W) with H == W; TODO confirm).
        x_recon_L, x_recon_R: reconstructions, same shapes as ``xL`` / ``xR``.
        mu, logvar: latent posterior parameters. The KL term sums over dim 1
            and averages over dim 0 (presumably (batch, latent_dim)).
        beta: KL weight, either a Python number or a tensor. It is not
            modified in place.
        left_mask, right_mask: optional masks; each may independently be
            ``None``. Elements where the mask, cast to int32, equals 1 are kept.

    Elements whose target value is 0, and elements excluded by a mask, are
    zeroed in both prediction and target, so they add nothing to the squared
    error. The mean still divides by the total element count, not by the
    number of valid elements.
    """
    image_size = xL.size(3)

    # Out-of-place division: ``beta /= ...`` would modify a caller-owned
    # tensor in place and compound on every call.
    beta = beta / image_size ** 2

    # Tutorial on VAEs, p. 14:
    #   log P(X|z) = C - \frac{1}{2} ||X - f(z)||^2 / \sigma^2
    #              = C - \frac{1}{2} \sum_{i=1}^{N} ||X^{(i)} - f(z^{(i)})||^2 / \sigma^2
    #              = C - \frac{1}{2} N * F.mse_loss(X_hat, X_true) / \sigma^2
    # With N = 2 * 192 * 192 (presumably the left and right images):
    #   log P(X|z) - C = -\frac{1}{2} * 2 * 192 * 192 / \sigma^2 * F.mse_loss
    # Therefore vae_beta = 1 / (36864 / \sigma^2). The division above applies
    # the 1 / Image_Size^2 factor; \sigma^2 is presumably folded into the
    # caller's beta (TODO confirm).

    # Keep only non-zero target pixels, further restricted by each mask.
    valid_mask_L = xL != 0
    valid_mask_R = xR != 0
    if left_mask is not None:
        valid_mask_L = valid_mask_L & (left_mask.detach().to(torch.int32) == 1)
    if right_mask is not None:
        valid_mask_R = valid_mask_R & (right_mask.detach().to(torch.int32) == 1)

    # reduction='mean' is the non-deprecated equivalent of size_average=True.
    MSE_L = F.mse_loss(x_recon_L * valid_mask_L, xL * valid_mask_L, reduction='mean')
    MSE_R = F.mse_loss(x_recon_R * valid_mask_R, xR * valid_mask_R, reduction='mean')

    # KL divergence to N(0, I): summed over latent dims, averaged over batch.
    KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()

    return KLD * beta + MSE_L + MSE_R