# Duplicated from fffiloni/lama (revision 24eb05d).
import logging
import math
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
import tqdm
from torch.utils.data import DataLoader
from saicinpainting.evaluation.utils import move_to_device
LOGGER = logging.getLogger(__name__)
class InpaintingEvaluator():
    """Offline evaluator: feeds every batch of a dataset to each score and collects the score values,
    optionally split into groups by the share of the image covered by the mask.
    """

    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples
            which are defined by share of area occluded by mask
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch_size for the dataloader
        :param device: device to use
        :param integral_func: optional callable; it receives the results dict and its return value is stored
            as dict(mean=...) under the key (integral_title, 'total')
        :param integral_title: name used in the results key of the integral metric
        :param clamp_image_range: optional (min, max) pair; when given, images are clamped to it before scoring
        """
        self.scores = scores
        self.dataset = dataset

        self.area_grouping = area_grouping
        self.bins = bins

        self.device = torch.device(device)

        # shuffle=False keeps the sample order stable, so the group indices computed in
        # _get_bin_edges line up with the order in which the scores see the samples.
        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    @staticmethod
    def _format_interval_names(bin_edges):
        """Build human-readable names such as '10-20%' for consecutive pairs of bin edges."""
        bins = len(bin_edges) - 1
        # More than 100 bins need fractional percents to keep the names distinct.
        num_digits = max(0, math.ceil(math.log10(bins)) - 1)
        interval_names = []
        for lower_edge, upper_edge in zip(bin_edges[:-1], bin_edges[1:]):
            start_percent = '{:.{n}f}'.format(round(100 * lower_edge, num_digits), n=num_digits)
            end_percent = '{:.{n}f}'.format(round(100 * upper_edge, num_digits), n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))
        return interval_names

    def _get_bin_edges(self):
        """
        Assign every sample of the dataset to a mask-area bin.

        :return: tuple (groups, interval_names): groups is a 1-D numpy array with one bin index per sample
            (in dataloader order), interval_names is the list of bin names indexed by bin index
        :raises ValueError: if self.bins is not positive
        """
        if self.bins < 1:
            # Without this check the failure is an opaque "math domain error" from math.log10.
            raise ValueError(f'bins must be a positive integer, got {self.bins}')

        bin_edges = np.linspace(0, 1, self.bins + 1)
        interval_names = self._format_interval_names(bin_edges)

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            batch_size = mask.shape[0]
            # Mean over all non-batch elements of the mask = share of masked area per sample.
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            # side='right' puts an area equal to an edge into the bin that starts at this edge.
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        groups = np.hstack(groups)

        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch.
            If None, every batch must already contain the results at key "inpainted".
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        # A plain `import tqdm` does not load the `tqdm.auto` submodule, so `tqdm.auto.tqdm` only works
        # when some other module happened to import it first; import it explicitly here.
        from tqdm.auto import tqdm as progress_bar

        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups, interval_names = None, None

        # NOTE: the dataset (and the model, if given) is processed once per score.
        for score_name, score in progress_bar(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in progress_bar(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so we expected precomputed inpainting results at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        # The integral metric is computed last because it reads the already collected results.
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
def ssim_fid100_f1(metrics, fid_scale=100):
    """
    Combine mean SSIM and mean FID into a single F1-like value.

    :param metrics: dict which contains ('ssim', 'total') and ('fid', 'total') entries, each with a 'mean' field
    :param fid_scale: FID values at or above this scale contribute a relative quality of zero
    :return: harmonic-mean-style combination of SSIM and the rescaled FID
    """
    ssim_mean = metrics[('ssim', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']

    # Map FID (lower is better) to a "bigger is better" value, floored at zero.
    fid_quality = max(0, fid_scale - fid_mean) / fid_scale

    numerator = 2 * ssim_mean * fid_quality
    denominator = ssim_mean + fid_quality + 1e-3  # the small constant avoids division by zero
    return numerator / denominator
def lpips_fid100_f1(metrics, fid_scale=100):
    """
    Combine mean LPIPS and mean FID into a single F1-like value.

    :param metrics: dict which contains ('lpips', 'total') and ('fid', 'total') entries, each with a 'mean' field
    :param fid_scale: FID values at or above this scale contribute a relative quality of zero
    :return: harmonic-mean-style combination of the inverted LPIPS and the rescaled FID
    """
    # LPIPS is "lower is better"; flip it so that bigger is better.
    lpips_quality = 1 - metrics[('lpips', 'total')]['mean']
    fid_mean = metrics[('fid', 'total')]['mean']

    # Map FID (lower is better) to a "bigger is better" value, floored at zero.
    fid_quality = max(0, fid_scale - fid_mean) / fid_scale

    numerator = 2 * lpips_quality * fid_quality
    denominator = lpips_quality + fid_quality + 1e-3  # the small constant avoids division by zero
    return numerator / denominator
class InpaintingEvaluatorOnline(nn.Module):
    """Evaluator which accumulates scores batch by batch (via forward / process_batch) and produces
    the final values, overall and grouped by mask area, in evaluation_end.
    """

    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: key of the batch dict which holds the reference images
        :param inpainted_key: key of the batch dict which holds the inpainted images
        :param integral_func: optional callable; it receives the results dict and its return value is stored
            as dict(mean=...) under the key (integral_title, 'total')
        :param integral_title: name used in the results key of the integral metric
        :param clamp_image_range: optional (min, max) pair; when given, images are clamped to it before scoring
        :raises ValueError: if bins is not positive
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        if bins < 1:
            # Without this check the failure is an opaque "math domain error" from math.log10.
            raise ValueError(f'bins must be a positive integer, got {bins}')

        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        # More than 100 bins need fractional percents to keep the names distinct.
        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for lower_edge, upper_edge in zip(self.bin_edges[:-1], self.bin_edges[1:]):
            start_percent = '{:.{n}f}'.format(round(100 * lower_edge, num_digits), n=num_digits)
            end_percent = '{:.{n}f}'.format(round(100 * upper_edge, num_digits), n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        # Bin index of every sample seen since the last evaluation_end, in processing order.
        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')

    def _get_bins(self, mask_batch):
        """
        :param mask_batch: tensor of masks with the batch in the first dimension
        :return: numpy array with one bin index (0 .. bins - 1) per sample, by share of masked area
        """
        batch_size = mask_batch.shape[0]
        # reshape (unlike view) also accepts non-contiguous masks, e.g. transposed or sliced ones.
        area = mask_batch.reshape(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        # searchsorted uses its default side='left' here, so an area equal to an inner edge falls into
        # the bin that ends at this edge; np.clip handles the corner cases area == 0 and area == 1.
        bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
        return bin_indices

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end
        :param batch: batch dict with mandatory fields mask, image, inpainted (the latter two can be overridden
            by self.image_key and self.inpainted_key)
        :return: dict {score_name: value returned by the score for this batch}
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        """Alias of forward: calculate and accumulate metrics for batch."""
        return self(batch)

    def evaluation_end(self, states=None):
        """
        Finalize evaluation, then reset the accumulated groups and all scores.

        :param states: optional sequence of dicts {score_name: state}; when given, the states of each score
            are passed to its get_value
        :return: dict with (score_name, group_type) as keys, where group_type can be either 'total' or
            name of the particular group arranged by area of mask (e.g. '10-20%')
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        # Keep self.groups a list: if a score raises below, later forward calls can still extend it.
        groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        # The integral metric is computed last because it reads the already collected results.
        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results