import sys
import argparse
import torch
import numpy as np
from torch.utils.data import DataLoader
sys.path.append(".")
sys.path.append("..")
from configs import data_configs
from datasets.images_dataset import ImagesDataset
from utils.model_utils import setup_model


class LEC:

    def __init__(self, net, is_cars=False):
        """
        Latent Editing Consistency metric as proposed in the main paper.
        :param net: e4e model loaded over the pSp framework.
        :param is_cars: Whether to crop the middle of the StyleGAN's output images (used for the cars domain).
        """
        self.net = net
        self.is_cars = is_cars

    def _encode(self, images):
        """
        Encodes the given images into StyleGAN's latent space.
        :param images: Tensor of shape NxCxHxW representing the images to be encoded.
        :return: Tensor of shape NxKx512 representing the latent space embeddings of the given images (in W(K, *) space).
        """
        codes = self.net.encoder(images)
        assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
        # normalize with respect to the center of an average face
        if self.net.opts.start_from_latent_avg:
            codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
        return codes

    def _generate(self, codes):
        """
        Generates StyleGAN2 images from the given latent codes.
        :param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
        :return: Tensor of shape NxCxHxW representing the generated images.
        """
        images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
        images = self.net.face_pool(images)
        if self.is_cars:
            # crop the vertical center of the output for the cars domain
            images = images[:, :, 32:224, :]
        return images

    @staticmethod
    def _filter_outliers(arr):
        # Keep only values between the 1st and 99th percentiles to reduce the effect of outliers.
        # NOTE: newer NumPy versions rename the "interpolation" argument to "method".
        arr = np.array(arr)
        lo = np.percentile(arr, 1, interpolation="lower")
        hi = np.percentile(arr, 99, interpolation="higher")
        return np.extract(
            np.logical_and(lo <= arr, arr <= hi), arr
        )

    def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
        """
        Calculates the LEC metric score.
        :param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
        :param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
                              latent space.
        :param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
                                      `edit_function` parameter.
        :return: The LEC metric score.
        """
        distances = []
        with torch.no_grad():
            for batch in data_loader:
                x, _ = batch
                # `device` is the module-level global set in the __main__ block below
                inputs = x.to(device).float()
                # Encode the inputs, edit them in latent space, and generate the edited images
                codes = self._encode(inputs)
                edited_codes = edit_function(codes)
                edited_image = self._generate(edited_codes)
                # Re-encode the edited images and apply the inverse edit
                edited_image_inversion_codes = self._encode(edited_image)
                inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)
                # LEC is the mean L2 distance between the original codes and the codes after the edit-invert cycle
                dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
                distances.append(dist.to("cpu").numpy())
        distances = self._filter_outliers(distances)
        return distances.mean()


if __name__ == "__main__":
    device = "cuda"
    parser = argparse.ArgumentParser(description="LEC metric calculator")
    parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
    parser.add_argument("--images_dir", type=str, default=None,
                        help="Path to the images directory on which we calculate the LEC score")
    parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
    args = parser.parse_args()
    print(args)

    net, opts = setup_model(args.ckpt, device)
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
    test_dataset = ImagesDataset(source_root=images_directory,
                                 target_root=images_directory,
                                 source_transform=transforms_dict['transform_source'],
                                 target_transform=transforms_dict['transform_test'],
                                 opts=opts)
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=True)
    print(f'dataset length: {len(test_dataset)}')

    # In the following example, we are using an InterfaceGAN-based editing to calculate the LEC metric.
    # Change the provided example according to your domain and needs.
    direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)

    def edit_func_example(codes):
        return codes + 3 * direction

    def inverse_edit_func_example(codes):
        return codes - 3 * direction

    lec = LEC(net, is_cars='car' in opts.dataset_type)
    result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
    print(f"LEC: {result}")