# Hosting-page header residue (not part of the module): AK391 / files / d380b77 / raw / history blame / 23.2 kB
import logging
from abc import abstractmethod, ABC
import numpy as np
import sklearn
import sklearn.svm
import torch
import torch.nn as nn
import torch.nn.functional as F
from joblib import Parallel, delayed
from scipy import linalg
from models.ade20k import SegmentationModule, NUM_CLASS, segm_options
from .fid.inception import InceptionV3
from .lpips import PerceptualLoss
from .ssim import SSIM
# Module-level logger named after this module; used below for FID warnings and progress messages.
LOGGER = logging.getLogger(__name__)
def get_groupings(groups):
    """
    Map every distinct group label to the positions of the elements carrying it.

    :param groups: group numbers for respective elements
    :return: dict of kind {group_idx: indices of the corresponding group elements};
        the indices inside each group are listed in ascending order
    """
    label_groups, count_groups = np.unique(groups, return_counts=True)
    # np.unique returns labels in sorted order, so after sorting the positions by
    # label the members of consecutive labels occupy consecutive slices.
    # A stable sort is requested explicitly: the default sort kind does not
    # guarantee the relative order of equal labels, which would make the order of
    # indices inside a group unspecified.
    indices = np.argsort(groups, kind='stable')

    grouping = dict()
    cur_start = 0
    for label, count in zip(label_groups, count_groups):
        cur_end = cur_start + count
        grouping[label] = indices[cur_start:cur_end]
        cur_start = cur_end
    return grouping
class EvaluatorScore(nn.Module):
    """
    Common interface of the evaluation scores in this module.

    The three hooks are marked with ``@abstractmethod``; this class itself only
    derives from ``nn.Module``, so the marker is enforced only in subclasses that
    also mix in ``ABC``. Every hook here is a no-op returning ``None``.
    """

    @abstractmethod
    def forward(self, pred_batch, target_batch, mask):
        """Process one batch of predictions and targets (no-op in the base class)."""

    @abstractmethod
    def get_value(self, groups=None, states=None):
        """Report the aggregated score (no-op in the base class)."""

    @abstractmethod
    def reset(self):
        """Drop accumulated state (no-op in the base class; subclasses call it via ``super()``)."""
class PairwiseScore(EvaluatorScore, ABC):
    """
    Base class for scores that yield one value per (prediction, target) pair.

    Subclasses accumulate the per-pair values in ``self.individual_values``;
    ``get_value`` reduces them to mean / std, optionally per group.
    """

    def __init__(self):
        super().__init__()
        # Stays None until reset() is called (the subclasses in this file call it in __init__).
        self.individual_values = None

    def get_value(self, groups=None, states=None):
        """
        :param groups: optional group numbers, one per scored element
        :param states: optional sequence of per-batch score tensors; when given,
            it is used instead of the values accumulated on this instance
        :return:
            total_results: dict of kind {'mean': score mean, 'std': score std}
            group_results: None, if groups is None;
                else dict {group_idx: {'mean': score mean among group, 'std': score std among group}}
        """
        if states is not None:
            # Flatten every batch before joining: gives the same element order as
            # stacking and flattening, but also works when the batches differ in
            # size (e.g. a smaller final batch), where torch.stack would fail.
            individual_values = torch.cat([state.reshape(-1) for state in states], dim=0).cpu().numpy()
        else:
            individual_values = self.individual_values

        total_results = {
            'mean': individual_values.mean(),
            'std': individual_values.std()
        }
        if groups is None:
            return total_results, None

        group_results = dict()
        for label, index in get_groupings(groups).items():
            group_scores = individual_values[index]
            group_results[label] = {
                'mean': group_scores.mean(),
                'std': group_scores.std()
            }
        return total_results, group_results

    def reset(self):
        """Forget all accumulated per-pair values."""
        self.individual_values = []
class SSIMScore(PairwiseScore):
    """Pairwise score backed by the project's ``SSIM`` module (built with ``size_average=False``)."""

    def __init__(self, window_size=11):
        super().__init__()
        ssim_module = SSIM(window_size=window_size, size_average=False)
        self.score = ssim_module.eval()
        self.reset()

    def forward(self, pred_batch, target_batch, mask=None):
        """
        Score one batch, append the values to the running record and return them.

        ``mask`` is accepted for interface compatibility and is not used.
        """
        batch_values = self.score(pred_batch, target_batch)
        fresh_values = batch_values.detach().cpu().numpy()
        self.individual_values = np.hstack((self.individual_values, fresh_values))
        return batch_values
class LPIPSScore(PairwiseScore):
    """Pairwise score backed by the project's ``PerceptualLoss`` (built with ``spatial=False``)."""

    def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True):
        super().__init__()
        perceptual = PerceptualLoss(model=model, net=net, model_path=model_path,
                                    use_gpu=use_gpu, spatial=False)
        self.score = perceptual.eval()
        self.reset()

    def forward(self, pred_batch, target_batch, mask=None):
        """
        Score one batch, append the flattened values to the running record and return them.

        ``mask`` is accepted for interface compatibility and is not used.
        """
        raw_values = self.score(pred_batch, target_batch)
        batch_values = raw_values.flatten()
        fresh_values = batch_values.detach().cpu().numpy()
        self.individual_values = np.hstack((self.individual_values, fresh_values))
        return batch_values
def fid_calculate_activation_statistics(act):
    """
    Return the statistics FID is built on.

    :param act: activations; axis 0 is averaged over, and ``np.cov`` is called with ``rowvar=False``
    :return: (mean over axis 0, covariance matrix)
    """
    return np.mean(act, axis=0), np.cov(act, rowvar=False)
def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6):
    """
    Frechet distance between the Gaussians fitted to two sets of activations:
    ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 . sigma2))``.

    :param activations_pred: activations of the predicted images
    :param activations_target: activations of the target images
    :param eps: value added to the diagonals of both covariances when the matrix
        square root of their product is not finite
    :return: the distance
    :raises ValueError: if the matrix square root has a diagonal imaginary part
        that is not close to zero (atol=1e-2)
    """
    mean_pred, cov_pred = fid_calculate_activation_statistics(activations_pred)
    mean_target, cov_target = fid_calculate_activation_statistics(activations_target)
    mean_gap = mean_pred - mean_target

    # The covariance product can be close to singular, which may make sqrtm return non-finite values.
    cov_sqrt, _ = linalg.sqrtm(cov_pred.dot(cov_target), disp=False)
    if not np.isfinite(cov_sqrt).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        LOGGER.warning(msg)
        jitter = np.eye(cov_pred.shape[0]) * eps
        cov_sqrt = linalg.sqrtm((cov_pred + jitter).dot(cov_target + jitter))

    # Numerical error can leave a small imaginary component; tolerate it, reject anything larger.
    if np.iscomplexobj(cov_sqrt):
        diagonal_imag = np.diagonal(cov_sqrt).imag
        if not np.allclose(diagonal_imag, 0, atol=1e-2):
            worst = np.max(np.abs(cov_sqrt.imag))
            raise ValueError('Imaginary component {}'.format(worst))
        cov_sqrt = cov_sqrt.real

    trace_cov_sqrt = np.trace(cov_sqrt)
    return (mean_gap.dot(mean_gap) + np.trace(cov_pred) +
            np.trace(cov_target) - 2 * trace_cov_sqrt)
class FIDScore(EvaluatorScore):
    """
    Accumulates Inception activations of predictions and targets and reports the
    Frechet distance between the two sets.

    The Inception network is cached on the class (``FIDScore._MODEL``) so that
    several score objects share one instance.
    """

    def __init__(self, dims=2048, eps=1e-6):
        """
        :param dims: key into ``InceptionV3.BLOCK_INDEX_BY_DIM`` selecting the feature block
        :param eps: forwarded to ``calculate_frechet_distance``
        """
        LOGGER.info("FIDscore init called")
        super().__init__()
        # Rebuild the shared network when nothing is cached yet OR when the cached
        # one was built for a different `dims`; previously a second instance with
        # another `dims` silently reused the first network. A cache entry without
        # a recorded `dims` (set by other code) is reused as before.
        if getattr(FIDScore, '_MODEL', None) is None \
                or getattr(FIDScore, '_MODEL_DIMS', dims) != dims:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            FIDScore._MODEL = InceptionV3([block_idx]).eval()
            FIDScore._MODEL_DIMS = dims
        self.model = FIDScore._MODEL
        self.eps = eps
        self.reset()
        LOGGER.info("FIDscore init done")

    def forward(self, pred_batch, target_batch, mask=None):
        """
        Compute activations for both batches, store detached CPU copies and
        return the activations. ``mask`` is not used.
        """
        activations_pred = self._get_activations(pred_batch)
        activations_target = self._get_activations(target_batch)
        self.activations_pred.append(activations_pred.detach().cpu())
        self.activations_target.append(activations_target.detach().cpu())
        return activations_pred, activations_target

    def get_value(self, groups=None, states=None):
        """
        :param groups: optional group numbers, one per image
        :param states: optional sequence of (activations_pred, activations_target)
            pairs; when given, it is used instead of the accumulated activations
        :return:
            total_results: dict {'mean': distance over all images}
            group_results: None, if groups is None;
                else dict {group_idx: {'mean': distance within the group}},
                with NaN for groups that hold a single image

        The accumulated activations are cleared before returning.
        """
        LOGGER.info("FIDscore get_value called")
        activations_pred, activations_target = zip(*states) if states is not None \
            else (self.activations_pred, self.activations_target)
        activations_pred = torch.cat(activations_pred).cpu().numpy()
        activations_target = torch.cat(activations_target).cpu().numpy()

        total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
        total_results = dict(mean=total_distance)

        if groups is None:
            group_results = None
        else:
            group_results = dict()
            grouping = get_groupings(groups)
            for label, index in grouping.items():
                if len(index) > 1:
                    group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index],
                                                                eps=self.eps)
                    group_results[label] = dict(mean=group_distance)
                else:
                    # A single image gives no usable covariance estimate.
                    group_results[label] = dict(mean=float('nan'))

        self.reset()

        LOGGER.info("FIDscore get_value done")
        return total_results, group_results

    def reset(self):
        """Drop all accumulated activations."""
        self.activations_pred = []
        self.activations_target = []

    def _get_activations(self, batch):
        """
        Run the network on ``batch`` and return its first output with the two
        trailing singleton dimensions removed.

        :raises AssertionError: if the trailing two dimensions are not 1x1
        """
        activations = self.model(batch)[0]
        if activations.shape[2] != 1 or activations.shape[3] != 1:
            # Raised explicitly (not via `assert`) so the guard survives `python -O`;
            # the exception type is unchanged.
            raise AssertionError(
                'We should not have got here, because Inception always scales inputs to 299x299')
        activations = activations.squeeze(-1).squeeze(-1)
        return activations
class SegmentationAwareScore(EvaluatorScore):
    """
    Collects, per image, histograms of segmentation labels: of the target over the
    whole image, of the target inside the mask, and of the prediction inside the mask.

    Labels come from ``SegmentationModule.predict(...)[0]``; histograms have at
    least ``NUM_CLASS`` bins.
    """

    def __init__(self, weights_path):
        super().__init__()
        segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True)
        self.segm_network = segm_network.eval()
        self.target_class_freq_by_image_total = []
        self.target_class_freq_by_image_mask = []
        self.pred_class_freq_by_image_mask = []

    def forward(self, pred_batch, target_batch, mask):
        """
        Segment both batches, record one histogram row per image and return the
        rows of this batch.

        :param mask: values above 0.5 mark the pixels that count as "inside the mask"
        :return: tuple of three arrays with one row per image
            (target over whole image, target inside mask, prediction inside mask)
        """
        batch_size = pred_batch.shape[0]
        pred_labels = self.segm_network.predict(pred_batch)[0]
        pred_labels = pred_labels.view(batch_size, -1).long().detach().cpu().numpy()
        target_labels = self.segm_network.predict(target_batch)[0]
        target_labels = target_labels.view(batch_size, -1).long().detach().cpu().numpy()
        inside_mask = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy()

        rows_target_total = []
        rows_target_mask = []
        rows_pred_mask = []

        for image_pred, image_target, image_mask in zip(pred_labels, target_labels, inside_mask):
            # [None, ...] turns each histogram into a 1-row matrix so rows can be concatenated later.
            hist_target_total = np.bincount(image_target, minlength=NUM_CLASS)[None, ...]
            hist_target_mask = np.bincount(image_target[image_mask], minlength=NUM_CLASS)[None, ...]
            hist_pred_mask = np.bincount(image_pred[image_mask], minlength=NUM_CLASS)[None, ...]

            self.target_class_freq_by_image_total.append(hist_target_total)
            self.target_class_freq_by_image_mask.append(hist_target_mask)
            self.pred_class_freq_by_image_mask.append(hist_pred_mask)

            rows_target_total.append(hist_target_total)
            rows_target_mask.append(hist_target_mask)
            rows_pred_mask.append(hist_pred_mask)

        return (np.concatenate(rows_target_total, axis=0),
                np.concatenate(rows_target_mask, axis=0),
                np.concatenate(rows_pred_mask, axis=0))

    def reset(self):
        """Drop all recorded histograms."""
        super().reset()
        self.target_class_freq_by_image_total = []
        self.target_class_freq_by_image_mask = []
        self.pred_class_freq_by_image_mask = []
def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name):
    """
    Average per-image values per class, weighting every image by the class count
    it contributes.

    :param target_class_freq_by_image_mask: 2-D array, one row per image, one column per class
    :param values: one value per image (same length as the number of rows)
    :param idx2name: lookup from class column index to the key used in the result
    :return: dict {class name: weighted average}, only for classes whose total count is positive
    """
    assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0]
    count_per_class = target_class_freq_by_image_mask.sum(0)
    weighted_sum = (target_class_freq_by_image_mask * values[..., None]).sum(0)
    # The small constant keeps the division defined for classes that never occur.
    averages = weighted_sum / (count_per_class + 1e-3)

    per_class = {}
    for class_idx, class_value in enumerate(averages):
        if count_per_class[class_idx] > 0:
            per_class[idx2name[class_idx]] = class_value
    return per_class
def get_segmentation_idx2name():
    """
    Build a {class index: class name} dict from ``segm_options['classes']``.

    The 'Idx' column is shifted down by one to form the keys; the 'Name' column
    supplies the values.
    """
    name_by_idx = segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict()
    return {idx - 1: name for idx, name in name_by_idx.items()}
class SegmentationAwarePairwiseScore(SegmentationAwareScore):
    """
    Per-image score that is additionally broken down by segmentation class.

    Subclasses provide ``calc_score``; its result is stored per batch next to the
    class histograms collected by the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.individual_values = []
        self.segm_idx2name = get_segmentation_idx2name()

    def forward(self, pred_batch, target_batch, mask):
        """Record class histograms and scores of the batch; return the histograms followed by the scores."""
        class_stats = super().forward(pred_batch, target_batch, mask)
        batch_scores = self.calc_score(pred_batch, target_batch, mask)
        self.individual_values.append(batch_scores)
        return class_stats + (batch_scores,)

    @abstractmethod
    def calc_score(self, pred_batch, target_batch, mask):
        """Return the scores of one batch; must be provided by subclasses."""
        raise NotImplementedError()

    def get_value(self, groups=None, states=None):
        """
        :param groups: optional group numbers, one per image
        :param states: optional 4-tuple (target histograms over whole image,
            target histograms in mask, prediction histograms in mask, scores),
            each a sequence of per-batch arrays; used instead of the accumulated state
        :return:
            total_results: dict of kind {'mean': score mean, 'std': score std}
                plus one entry per segmentation class present in the masks
            group_results: None, if groups is None;
                else dict {group_idx: same kind of dict, computed among the group}
        """
        if states is None:
            states = (self.target_class_freq_by_image_total,
                      self.target_class_freq_by_image_mask,
                      self.pred_class_freq_by_image_mask,
                      self.individual_values)
        total_parts, mask_parts, pred_parts, score_parts = states

        # The first and third arrays are not used below; they are still
        # concatenated so that malformed state fails exactly as before.
        target_class_freq_by_image_total = np.concatenate(total_parts, axis=0)
        target_class_freq_by_image_mask = np.concatenate(mask_parts, axis=0)
        pred_class_freq_by_image_mask = np.concatenate(pred_parts, axis=0)
        scores = np.concatenate(score_parts, axis=0)

        total_results = self._summarize(target_class_freq_by_image_mask, scores)
        if groups is None:
            return total_results, None

        group_results = dict()
        for label, index in get_groupings(groups).items():
            group_results[label] = self._summarize(target_class_freq_by_image_mask[index], scores[index])
        return total_results, group_results

    def _summarize(self, class_freq, scores):
        """Mean / std of ``scores`` plus their per-class breakdown weighted by ``class_freq``."""
        return {
            'mean': scores.mean(),
            'std': scores.std(),
            **distribute_values_to_classes(class_freq, scores, self.segm_idx2name)
        }

    def reset(self):
        """Drop recorded histograms and scores."""
        super().reset()
        self.individual_values = []
class SegmentationClassStats(SegmentationAwarePairwiseScore):
    """
    Reports segmentation class frequencies instead of an image-quality score:
    how often each class occurs in the targets (whole image and inside the mask)
    and how the predicted frequency inside the mask deviates from the target one.
    """

    def calc_score(self, pred_batch, target_batch, mask):
        """No per-image score is produced; a constant 0 is recorded."""
        return 0

    def get_value(self, groups=None, states=None):
        """
        :param groups: optional group numbers, one per image
        :param states: optional 4-tuple (target histograms over whole image,
            target histograms in mask, prediction histograms in mask, ignored),
            the first three being sequences of per-batch arrays; used instead
            of the accumulated state
        :return:
            total_results: dict with keys
                'total_freq/<class>': share of the class among all target pixels,
                'mask_freq/<class>': share of the class among target pixels in the mask,
                'mask_freq_diff/<class>': (predicted - target) count in the mask,
                    relative to the target count in the mask
            group_results: None, if groups is None;
                else dict {group_idx: same kind of dict, computed among the group}
        """
        if states is not None:
            (target_class_freq_by_image_total,
             target_class_freq_by_image_mask,
             pred_class_freq_by_image_mask,
             _) = states
        else:
            target_class_freq_by_image_total = self.target_class_freq_by_image_total
            target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
            pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask

        target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0)
        target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
        pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0)

        total_results = self._class_freq_stats(target_class_freq_by_image_total,
                                               target_class_freq_by_image_mask,
                                               pred_class_freq_by_image_mask)
        if groups is None:
            return total_results, None

        group_results = dict()
        for label, index in get_groupings(groups).items():
            group_results[label] = self._class_freq_stats(target_class_freq_by_image_total[index],
                                                          target_class_freq_by_image_mask[index],
                                                          pred_class_freq_by_image_mask[index])
        return total_results, group_results

    def _class_freq_stats(self, target_freq_total, target_freq_mask, pred_freq_mask):
        """Frequency statistics for one set of per-image histogram rows (all images or one group)."""
        total_marginal = target_freq_total.sum(0).astype('float32')
        total_marginal /= total_marginal.sum()

        mask_marginal = target_freq_mask.sum(0).astype('float32')
        mask_marginal /= mask_marginal.sum()

        # The small constant keeps the division defined for classes absent from the masks.
        pred_freq_diff = (pred_freq_mask - target_freq_mask).sum(0) / (target_freq_mask.sum(0) + 1e-3)

        stats = dict()
        stats.update({f'total_freq/{self.segm_idx2name[i]}': v
                      for i, v in enumerate(total_marginal)
                      if v > 0})
        stats.update({f'mask_freq/{self.segm_idx2name[i]}': v
                      for i, v in enumerate(mask_marginal)
                      if v > 0})
        # Reported for every class present anywhere in the targets, not only inside the mask.
        stats.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v
                      for i, v in enumerate(pred_freq_diff)
                      if total_marginal[i] > 0})
        return stats
class SegmentationAwareSSIM(SegmentationAwarePairwiseScore):
    """Segmentation-aware score whose per-image values come from the project's ``SSIM`` module."""

    def __init__(self, *args, window_size=11, **kwargs):
        super().__init__(*args, **kwargs)
        ssim_module = SSIM(window_size=window_size, size_average=False)
        self.score_impl = ssim_module.eval()

    def calc_score(self, pred_batch, target_batch, mask):
        """Return the SSIM output for the batch as a numpy array; ``mask`` is not used."""
        batch_scores = self.score_impl(pred_batch, target_batch)
        return batch_scores.detach().cpu().numpy()
class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore):
    """Segmentation-aware score whose per-image values come from the project's ``PerceptualLoss``."""

    def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs):
        super().__init__(*args, **kwargs)
        perceptual = PerceptualLoss(model=model, net=net, model_path=model_path,
                                    use_gpu=use_gpu, spatial=False)
        self.score_impl = perceptual.eval()

    def calc_score(self, pred_batch, target_batch, mask):
        """Return the flattened loss output for the batch as a numpy array; ``mask`` is not used."""
        batch_scores = self.score_impl(pred_batch, target_batch)
        return batch_scores.flatten().detach().cpu().numpy()
def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6):
    """
    Frechet distance obtained when the prediction activations of image ``img_i``
    are replaced by that image's target activations.

    The caller's ``activations_pred`` is left untouched; a copy is modified.
    """
    patched_pred = activations_pred.copy()
    patched_pred[img_i] = activations_target[img_i]
    return calculate_frechet_distance(patched_pred, activations_target, eps=eps)
class SegmentationAwareFID(SegmentationAwarePairwiseScore):
    """
    FID with a per-class breakdown: each image's contribution is estimated as the
    change in FID when its prediction activations are replaced by its target
    activations, and these contributions are distributed over segmentation classes.

    The Inception network is shared with ``FIDScore`` through ``FIDScore._MODEL``.
    """

    def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs):
        """
        :param dims: key into ``InceptionV3.BLOCK_INDEX_BY_DIM`` selecting the feature block
        :param eps: forwarded to ``calculate_frechet_distance``
        :param n_jobs: forwarded to ``joblib.Parallel`` for the per-image FID computations
        """
        super().__init__(*args, **kwargs)
        # Rebuild the shared network when nothing is cached yet OR when the cached
        # one was built for a different `dims`; previously an instance created with
        # another `dims` silently reused the first network. A cache entry without
        # a recorded `dims` is reused as before.
        if getattr(FIDScore, '_MODEL', None) is None \
                or getattr(FIDScore, '_MODEL_DIMS', dims) != dims:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            FIDScore._MODEL = InceptionV3([block_idx]).eval()
            FIDScore._MODEL_DIMS = dims
        self.model = FIDScore._MODEL
        self.eps = eps
        self.n_jobs = n_jobs

    def calc_score(self, pred_batch, target_batch, mask):
        """Return the (prediction, target) activation pair of the batch; ``mask`` is not used."""
        activations_pred = self._get_activations(pred_batch)
        activations_target = self._get_activations(target_batch)
        return activations_pred, activations_target

    def get_value(self, groups=None, states=None):
        """
        :param groups: optional group numbers, one per image
        :param states: optional 4-tuple whose second item holds the per-batch target
            histograms inside the mask and whose fourth item holds the per-batch
            (activations_pred, activations_target) pairs; used instead of the
            accumulated state
        :return:
            total_results: dict {'mean': FID over all images, 'std': 0}
                plus one entry per segmentation class present in the masks
            group_results: None, if groups is None;
                else dict {group_idx: same kind of dict, computed among the group},
                with {'mean': NaN, 'std': 0} for groups that hold a single image
        """
        if states is not None:
            (_,
             target_class_freq_by_image_mask,
             _,
             activation_pairs) = states
        else:
            target_class_freq_by_image_mask = self.target_class_freq_by_image_mask
            activation_pairs = self.individual_values

        target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0)
        activations_pred, activations_target = zip(*activation_pairs)
        activations_pred = np.concatenate(activations_pred, axis=0)
        activations_target = np.concatenate(activations_target, axis=0)

        # Computed once and handed to the per-class breakdown, which needs the same value.
        total_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)
        total_results = {
            'mean': total_fid,
            'std': 0,
            **self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred,
                                             activations_target, real_fid=total_fid)
        }

        if groups is None:
            return total_results, None

        group_results = dict()
        grouping = get_groupings(groups)
        for label, index in grouping.items():
            if len(index) > 1:
                group_activations_pred = activations_pred[index]
                group_activations_target = activations_target[index]
                group_class_freq = target_class_freq_by_image_mask[index]
                group_fid = calculate_frechet_distance(group_activations_pred, group_activations_target,
                                                       eps=self.eps)
                group_results[label] = {
                    'mean': group_fid,
                    'std': 0,
                    **self.distribute_fid_to_classes(group_class_freq,
                                                     group_activations_pred,
                                                     group_activations_target,
                                                     real_fid=group_fid)
                }
            else:
                # A single image gives no usable covariance estimate.
                group_results[label] = dict(mean=float('nan'), std=0)
        return total_results, group_results

    def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target, real_fid=None):
        """
        Break the FID down by segmentation class.

        :param class_freq: 2-D array of per-image class counts (one row per image)
        :param activations_pred: prediction activations, one row per image
        :param activations_target: target activations, one row per image
        :param real_fid: FID of the two activation sets if the caller already has
            it; computed here when omitted
        :return: dict {class name: weighted average of per-image FID changes}
        """
        if real_fid is None:
            real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps)

        fid_no_images = Parallel(n_jobs=self.n_jobs)(
            delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps)
            for img_i in range(activations_pred.shape[0])
        )

        # Parallel returns a plain list; convert explicitly instead of relying on
        # the scalar operand to coerce it.
        errors = real_fid - np.asarray(fid_no_images)
        return distribute_values_to_classes(class_freq, errors, self.segm_idx2name)

    def _get_activations(self, batch):
        """
        Run the network on ``batch`` and return its first output as a numpy array,
        average-pooled to 1x1 over the trailing two dimensions if needed and with
        those dimensions removed.
        """
        activations = self.model(batch)[0]
        if activations.shape[2] != 1 or activations.shape[3] != 1:
            activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1))
        activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy()
        return activations