File size: 4,154 Bytes
8ca3a29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# python3.7
"""Computes the semantic directions regarding a specific image region."""

import os
import argparse
import numpy as np
from tqdm import tqdm

from coordinate import COORDINATES
from coordinate import get_mask
from utils.image_utils import save_image


def parse_args(argv=None):
    """Parses command-line arguments.

    Args:
        argv: Optional list of argument strings. When `None` (the default),
            the arguments are read from `sys.argv[1:]`, as before.

    Returns:
        An `argparse.Namespace` with the attributes `jaco_path`, `region`,
        `save_dir`, `job`, `name`, `data_name`, `full_rank` and `tao`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('jaco_path', type=str,
                        help='Path to the Jacobian matrices (loaded with '
                             '`np.load`).')
    parser.add_argument('--region', type=str, default='eyes',
                        help='The region used as the foreground when '
                             'computing the directions (default: eyes).')
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment but the last must end with a space.
    parser.add_argument('--save_dir', type=str, default='',
                        help='Directory to save the results. If not '
                             'specified, the results will be saved to '
                             '`./work_dirs/{job}/{data_name}/{region}/` '
                             'by default.')
    parser.add_argument('--job', type=str, default='directions',
                        help='Name for the job (default: directions).')
    parser.add_argument('--name', type=str, default='resefa',
                        help='Name appended to the saved file names '
                             '(default: resefa).')
    parser.add_argument('--data_name', type=str, default='ffhq',
                        help='Name of the dataset (default: ffhq).')
    parser.add_argument('--full_rank', action='store_true',
                        help='Whether to add a scaled identity matrix to the '
                             'background matrix to make it full rank '
                             '(default: False).')
    parser.add_argument('--tao', type=float, default=1e-3,
                        help='Coefficient of the identity matrix, relative to '
                             'the trace of the background matrix; only used '
                             'with `--full_rank` (default: 1e-3).')
    return parser.parse_args(argv)


def _split_jacobian(jacobian, foreground_ind, background_ind, w_dim):
    """Selects the foreground and background rows of one Jacobian.

    Args:
        jacobian: Jacobian of a single image, either `[H, W, 1, latent_dim]`
            or `[channel, H, W, 1, latent_dim]`.
        foreground_ind: `np.where` result (row, col) of the region pixels.
        background_ind: `np.where` result (row, col) of the other pixels.
        w_dim: Size of the last (latent) axis.

    Returns:
        A `(jaco_fore, jaco_back)` tuple, each reshaped to `[-1, w_dim]`.

    Raises:
        ValueError: If `jacobian` is neither 4-D nor 5-D.
    """
    if jacobian.ndim == 4:  # [H, W, 1, latent_dim]
        jaco_fore = jacobian[foreground_ind[0], foreground_ind[1], 0]
        jaco_back = jacobian[background_ind[0], background_ind[1], 0]
    elif jacobian.ndim == 5:  # [channel, H, W, 1, latent_dim]
        jaco_fore = jacobian[:, foreground_ind[0], foreground_ind[1], 0]
        jaco_back = jacobian[:, background_ind[0], background_ind[1], 0]
    else:
        raise ValueError('Shape of the Jacobian is not correct!')
    return (np.reshape(jaco_fore, [-1, w_dim]),
            np.reshape(jaco_back, [-1, w_dim]))


def _compute_directions(jaco_fore, jaco_back, full_rank=False, tao=1e-3):
    """Solves `inv(M_back) @ M_fore x = lambda x` for the region directions.

    `M_fore` and `M_back` are the row-averaged Gram matrices `J^T J / n` of
    the foreground and background Jacobian rows.

    Args:
        jaco_fore: Foreground Jacobian rows, `[n_fore, w_dim]`, `n_fore > 0`.
        jaco_back: Background Jacobian rows, `[n_back, w_dim]`, `n_back > 0`.
        full_rank: If set, adds `tao * trace(M_back) * I` to `M_back`.
        tao: Relative weight of the identity term.

    Returns:
        `[w_dim, w_dim]` array whose rows are the eigenvectors (real parts),
        sorted by descending real eigenvalue.

    Raises:
        numpy.linalg.LinAlgError: If `M_back` is singular.
    """
    m_fore = jaco_fore.T.dot(jaco_fore) / jaco_fore.shape[0]
    m_back = jaco_back.T.dot(jaco_back) / jaco_back.shape[0]
    if full_rank:
        # M_back <- M_back + tao * trace(M_back) * I
        m_back = m_back + tao * np.trace(m_back) * np.identity(m_back.shape[0])
    # `solve(B, A)` equals `inv(B) @ A` without forming the explicit inverse,
    # which is more accurate for ill-conditioned B.
    eig_val, eig_vec = np.linalg.eig(np.linalg.solve(m_back, m_fore))
    eig_val = np.real(eig_val)
    directions = np.real(eig_vec).T  # eigenvectors are columns of `eig_vec`
    return directions[np.argsort(-eig_val)]


def main():
    """Main function.

    Loads the Jacobians from `args.jaco_path`, splits each one into the rows
    of the pixels inside `args.region` (foreground) and outside it
    (background), and saves per image:
      * `image_{ind:02d}_region_{region}_name_{name}.npy`: directions sorted
        by descending eigenvalue of `inv(M_back) @ M_fore`;
      * `..._mask.png`: the region mask as an RGB image.

    Raises:
        FileNotFoundError: If `args.jaco_path` does not exist.
        ValueError: If the region is undefined for the dataset, the mask
            leaves the foreground or background empty, or a Jacobian has an
            unsupported shape.
    """
    args = parse_args()
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.jaco_path):
        raise FileNotFoundError(
            f'Jacobian file `{args.jaco_path}` does not exist!')
    jacobians = np.load(args.jaco_path)
    # NOTE(review): axis 2 is W for [N, H, W, 1, D] and H for
    # [N, C, H, W, 1, D]; `get_mask` takes a single size, so square images
    # are assumed — confirm.
    image_size = jacobians.shape[2]
    w_dim = jacobians.shape[-1]
    coord_dict = COORDINATES[args.data_name]
    if args.region not in coord_dict:
        raise ValueError(
            f'{args.region} coordinate is not defined in '
            f'COORDINATE_{args.data_name}. Please define this region first!')
    coords = coord_dict[args.region]
    mask = get_mask(image_size, coordinate=coords)
    foreground_ind = np.where(mask == 1)
    background_ind = np.where((1 - mask) == 1)
    # An empty side would otherwise divide by zero when averaging.
    if foreground_ind[0].size == 0 or background_ind[0].size == 0:
        raise ValueError(
            f'Region `{args.region}` yields an empty foreground or '
            f'background mask; cannot compute directions.')
    temp_dir = f'./work_dirs/{args.job}/{args.data_name}/{args.region}'
    save_dir = args.save_dir or temp_dir
    os.makedirs(save_dir, exist_ok=True)
    # The mask is identical for every image, so build its preview once.
    mask_rgb = (np.tile(mask[:, :, np.newaxis], [1, 1, 3]) * 255).astype(
        np.uint8)
    if args.full_rank:
        print('Using full rank')
    for ind in tqdm(range(jacobians.shape[0])):
        jaco_fore, jaco_back = _split_jacobian(
            jacobians[ind], foreground_ind, background_ind, w_dim)
        directions = _compute_directions(
            jaco_fore, jaco_back, full_rank=args.full_rank, tao=args.tao)
        save_name = f'{save_dir}/image_{ind:02d}_region_{args.region}' \
                    f'_name_{args.name}'
        np.save(f'{save_name}.npy', directions)
        save_image(f'{save_name}_mask.png', mask_rgb)

if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    main()