import bisect

import torch
import torch.nn.functional as F
import lpips

# LPIPS perceptual-similarity model, built once at import time and shared by
# every call to ``distance``.
perceptual_loss = lpips.LPIPS()


def distance(img_a, img_b):
    """Return the LPIPS perceptual distance between two images as a float.

    The images are passed unchanged to the shared ``lpips.LPIPS`` model, so
    they must be in whatever layout/value range that model expects.
    NOTE(review): ``.item()`` requires a one-element result, so callers
    presumably pass single-image inputs; confirm against callers.
    """
    # Only the scalar value is used, so don't record an autograd graph.
    # Alternative metric: return F.mse_loss(img_a, img_b).item()
    with torch.no_grad():
        return perceptual_loss(img_a, img_b).item()


class AlphaScheduler:
    """Re-parameterise an image sequence by cumulative perceptual distance.

    ``from_imgs`` stores, for each image index ``i``, the cumulative distance
    from image 0 to image ``i`` divided by the total distance, giving a
    non-decreasing curve with ``values[0] == 0`` and ``values[-1] == 1``.
    Image ``i`` sits at position ``x = i / (n - 1)``. ``get_x`` inverts this
    piecewise-linear curve: for a target fraction ``y`` it returns the
    position ``x`` at which that fraction of the total distance is reached.
    """

    def __init__(self):
        # Empty until from_imgs() or load() populates the curve.
        self.__values = []
        self.__num_values = 0

    def from_imgs(self, imgs):
        """Build the normalised cumulative-distance curve from ``imgs``.

        ``imgs`` must support ``len()`` and integer indexing; consecutive
        elements are compared with ``distance``.

        If every consecutive distance is zero (all images identical), the
        curve falls back to uniform spacing so that ``get_x`` is the identity.

        Raises:
            ValueError: if fewer than two images are given (the curve and
                the ``1 / (n - 1)`` image spacing are undefined).
        """
        num = len(imgs)
        if num < 2:
            raise ValueError(
                f"AlphaScheduler.from_imgs needs at least 2 images, got {num}"
            )

        # Running sum of consecutive distances, starting at 0 for image 0.
        values = [0]
        for i in range(num - 1):
            values.append(values[-1] + distance(imgs[i], imgs[i + 1]))

        total = values[-1]
        if total > 0:
            values = [v / total for v in values]
        else:
            # Avoid 0/0: with no measurable change, progress is linear.
            values = [i / (num - 1) for i in range(num)]

        self.__num_values = num
        self.__values = values

    def save(self, filename):
        """Save the curve to ``filename`` as a 1-D tensor via ``torch.save``."""
        torch.save(torch.tensor(self.__values), filename)

    def load(self, filename):
        """Load a curve previously written by ``save``.

        NOTE(review): ``torch.load`` unpickles data; only load trusted files.
        """
        self.__values = torch.load(filename).tolist()
        self.__num_values = len(self.__values)

    def get_x(self, y):
        """Return the position ``x`` in [0, 1] where the curve reaches ``y``.

        Linearly interpolates within the segment ``(values[k], values[k+1]]``
        that contains ``y``; that segment spans ``x`` from ``k / (n - 1)`` to
        ``(k + 1) / (n - 1)``. ``y`` may be a Python number or a 0-d tensor;
        the result then has the corresponding type.

        On a zero-width segment (consecutive identical images), every ``x``
        in the segment maps to the same ``y``, so the segment's left end is
        returned instead of dividing by zero.

        Raises:
            ValueError: if ``y`` is not within [0, 1] (including NaN).
            RuntimeError: if the curve has not been built or loaded.
        """
        if not 0 <= y <= 1:
            raise ValueError(f"y must be in [0, 1], got {y}")
        if self.__num_values < 2:
            raise RuntimeError(
                "AlphaScheduler has no curve; call from_imgs() or load() first"
            )

        # bisect_left gives the first k with values[k] >= y, so y lies in
        # (values[k-1], values[k]]; clamp to a valid segment index.
        seg = bisect.bisect_left(self.__values, y) - 1
        seg = min(max(seg, 0), self.__num_values - 2)

        yl = self.__values[seg]
        yr = self.__values[seg + 1]
        xl = seg * (1 / (self.__num_values - 1))
        xr = (seg + 1) * (1 / (self.__num_values - 1))

        if yr == yl:
            return xl
        return (y - yl) / (yr - yl) * (xr - xl) + xl

    def get_list(self, len=None):
        """Return ``get_x`` evaluated at ``len`` evenly spaced ``y`` in [0, 1].

        ``len`` defaults to the number of stored curve values. (The parameter
        name shadows the builtin but is kept for caller compatibility.)
        Because the ``y`` values come from ``torch.linspace``, the returned
        elements are generally 0-d tensors.
        """
        if len is None:
            len = self.__num_values
        ys = torch.linspace(0, 1, len)
        return [self.get_x(y) for y in ys]