AK391
files
d380b77
raw history blame
No virus
9.73 kB
import logging
import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import tqdm
import tqdm.auto
from torch.utils.data import DataLoader

from saicinpainting.evaluation.utils import move_to_device
# Logger shared by everything in this module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)
class InpaintingEvaluator:
    """Evaluate inpainting quality over a whole dataset with a set of scores.

    Every score is computed over the full dataset. Optionally, results are additionally
    split into groups of samples by the share of the image area covered by the mask.
    """

    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        :param integral_func: optional callable; receives the results dict built by ``evaluate`` and
            returns a single number, stored under ``(integral_title, 'total')`` as ``{'mean': value}``
        :param integral_title: score name under which the ``integral_func`` value is stored
        :param clamp_image_range: optional ``(min, max)`` pair; ground-truth images are clamped to this
            range before being scored
        """
        self.scores = scores
        self.dataset = dataset
        self.area_grouping = area_grouping
        self.bins = bins
        self.device = torch.device(device)
        # shuffle=False: area groups are computed in a separate pass over the dataloader, so the
        # sample order must match the order in which the scores later see the data.
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)
        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    def _get_interval_names(self):
        """Return one human-readable label (e.g. '10-20%') per area bin."""
        bin_edges = np.linspace(0, 1, self.bins + 1)
        # More bins need more decimal places to keep labels distinct (e.g. 100 bins -> 1 digit).
        num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
        interval_names = []
        for start, end in zip(bin_edges[:-1], bin_edges[1:]):
            start_percent = '{:.{n}f}'.format(round(100 * start, num_digits), n=num_digits)
            end_percent = '{:.{n}f}'.format(round(100 * end, num_digits), n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))
        return interval_names

    def _get_bin_edges(self):
        """Assign every sample of the dataset to an area bin.

        :return: tuple ``(groups, interval_names)``. ``groups`` is a 1-D integer array holding the
            bin index of every sample in dataloader order (empty if the dataloader yields nothing).
            ``interval_names[i]`` is the label of bin ``i``.
        """
        bin_edges = np.linspace(0, 1, self.bins + 1)
        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # Bins are half-open [a, b). searchsorted yields `bins` for area == 1, and -1 for a
            # negative area; clipping maps both onto a valid bin.
            groups.append(np.clip(bin_indices, 0, self.bins - 1))
        groups = np.hstack(groups) if groups else np.zeros(0, dtype=np.int64)
        return groups, self._get_interval_names()

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch.
            If None, every batch must already contain the inpainting result under the key 'inpainted'.
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups, interval_names = None, None

        # NOTE(review): the dataloader (and therefore `model`) is run once per score; a single pass
        # feeding all scores would be cheaper if the scores' memory use allows it.
        for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        # Computed after the loop because the integral function reads the other scores' results.
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
def ssim_fid100_f1(metrics, fid_scale=100):
    """Fold mean SSIM and mean FID into one F1-style score (larger is better).

    FID is first turned into a "larger is better" value: ``max(0, fid_scale - fid) / fid_scale``.
    This is 1.0 at fid == 0 and 0.0 once fid >= fid_scale. Its harmonic mean with SSIM is then
    taken, with 1e-3 added to the denominator to keep it away from zero.

    :param metrics: results dict keyed by ``(score_name, group)``; reads
        ``metrics[('ssim', 'total')]['mean']`` and ``metrics[('fid', 'total')]['mean']``
    :param fid_scale: FID value at (and above) which the FID term becomes 0
    :return: the combined score as a float
    """
    total_ssim = metrics[('ssim', 'total')]['mean']
    total_fid = metrics[('fid', 'total')]['mean']
    relative_fid = max(0, fid_scale - total_fid) / fid_scale

    numerator = 2 * total_ssim * relative_fid
    denominator = total_ssim + relative_fid + 1e-3
    return numerator / denominator
def lpips_fid100_f1(metrics, fid_scale=100):
    """Fold mean LPIPS and mean FID into one F1-style score (larger is better).

    LPIPS is inverted as ``1 - lpips``, and FID becomes ``max(0, fid_scale - fid) / fid_scale``, so
    that larger is better for both terms. Their harmonic mean is then taken, with 1e-3 added to the
    denominator to keep it away from zero.

    :param metrics: results dict keyed by ``(score_name, group)``; reads
        ``metrics[('lpips', 'total')]['mean']`` and ``metrics[('fid', 'total')]['mean']``
    :param fid_scale: FID value at (and above) which the FID term becomes 0
    :return: the combined score as a float
    """
    lpips_similarity = 1 - metrics[('lpips', 'total')]['mean']  # inverted LPIPS: bigger is better
    total_fid = metrics[('fid', 'total')]['mean']
    relative_fid = max(0, fid_scale - total_fid) / fid_scale

    numerator = 2 * lpips_similarity * relative_fid
    denominator = lpips_similarity + relative_fid + 1e-3
    return numerator / denominator
class InpaintingEvaluatorOnline(nn.Module):
    """Accumulate inpainting scores batch by batch and aggregate them in ``evaluation_end``.

    Besides the overall ('total') value of every score, results are also reported per group of
    samples. Groups are defined by the share of image area covered by the mask.
    """

    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: batch key holding the ground-truth images
        :param inpainted_key: batch key holding the inpainted images
        :param integral_func: optional callable; receives the results dict built in ``evaluation_end``
            and returns a single number, stored under ``(integral_title, 'total')`` as ``{'mean': value}``
        :param integral_title: score name under which the ``integral_func`` value is stored
        :param clamp_image_range: optional ``(min, max)`` pair; ground-truth images are clamped to this
            range before being scored
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        # More bins need more decimal places to keep labels distinct (e.g. 100 bins -> 1 digit).
        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for start, end in zip(self.bin_edges[:-1], self.bin_edges[1:]):
            start_percent = '{:.{n}f}'.format(round(100 * start, num_digits), n=num_digits)
            end_percent = '{:.{n}f}'.format(round(100 * end, num_digits), n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        # Bin index of every sample passed to forward() since the last evaluation_end, in call order.
        self.groups = []
        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range
        LOGGER.info(f'{type(self)} init done')

    def _get_bins(self, mask_batch):
        """Return the area-bin index of every sample in ``mask_batch``.

        Bins are half-open ``[a, b)``, with area == 1 placed in the last bin. This is the same
        convention as ``InpaintingEvaluator._get_bin_edges``, so both evaluators label a given
        mask identically.
        """
        batch_size = mask_batch.shape[0]
        # reshape (not view) so that non-contiguous mask tensors are accepted as well
        area = mask_batch.reshape(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        bin_indices = np.searchsorted(self.bin_edges, area, side='right') - 1
        return np.clip(bin_indices, 0, self.bins_num - 1)

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image, inpainted (the image and inpainted keys can be
            overridden by self.image_key / self.inpainted_key)
        :return: dict {score_name: value returned by that score's call on this batch}
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        """Alias for calling the module on ``batch``; see ``forward``."""
        return self(batch)

    def evaluation_end(self, states=None):
        """Aggregate the accumulated scores, then reset the groups and all scores.

        :param states: optional sequence of dicts keyed by score name. For each score, the list of
            its entries is passed to ``score.get_value(states=...)`` (presumably per-process states
            gathered for distributed evaluation -- TODO confirm against callers).
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')
        # Use a local array instead of overwriting self.groups. If get_value raises, self.groups
        # stays a list and forward() keeps working. The explicit dtype also keeps an empty
        # array integer-typed.
        groups = np.asarray(self.groups, dtype=np.int64)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        # Computed after the loop because the integral function reads the other scores' results.
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results