import gc

import torch
import torch.nn.functional as F

from flow.flow_utils import flow_warp


# AdaIn
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    Args:
        feat: A 4-D tensor; the first two dimensions are treated as
            ``(N, C)`` and all remaining elements are flattened together.
        eps: Small value added to the variance before the square root to
            avoid divide-by-zero.

    Returns:
        A ``(mean, std)`` tuple, each reshaped to ``(N, C, 1, 1)``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    flat = feat.view(N, C, -1)
    # NOTE(review): ``var`` is called with torch's default correction, so a
    # single-element spatial extent would yield NaN -- confirm callers never
    # pass 1x1 feature maps.
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


class AttentionControl():
    """Store and re-inject attention contexts and x0 predictions.

    The behaviour of :meth:`forward` and :meth:`update_x0` is driven by the
    boolean flags selected through :meth:`set_task` and by the current step
    set through :meth:`set_step` / :meth:`set_total_step`.

    ``cross_period``, ``ada_period`` and ``warp_period`` are each indexed
    with ``[0]`` and ``[1]`` and multiplied by ``total_step`` to form an
    inclusive step window (presumably fractions of the total step count --
    verify against the caller). ``inner_strength`` and ``mask_period`` are
    only stored here; nothing in this class reads them.
    """

    def __init__(self, inner_strength, mask_period, cross_period, ada_period,
                 warp_period):
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        self.cur_index = 0
        # Task flags; all are reset by ``set_task``.
        self.init_store = False
        self.restore = False
        self.update = False
        self.restorex0 = False
        self.updatex0 = False
        # Matches the default of ``set_task`` so the attribute exists even
        # before the first task is selected.
        self.restore_step = 1.0
        # Warp inputs; populated by ``set_warp``.
        self.flow = None
        self.mask = None
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period

    @staticmethod
    def get_empty_store():
        """Return a fresh store with one empty list per tracked quantity."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': []
        }

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Record and/or replace an attention context.

        Only calls with ``is_cross`` false and ``place_in_unet == 'up'`` are
        processed; every other call returns ``context`` untouched and does
        not advance ``cur_index``.

        Args:
            context: The context tensor handed in by the attention layer.
            is_cross: Whether the calling layer is a cross-attention layer.
            place_in_unet: Location tag of the calling layer.

        Returns:
            Either ``context`` itself or, when restoring inside the cross
            window, the stored 'first' and 'previous' entries for the
            current index concatenated along dimension 1.
        """
        window_start = self.total_step * self.cross_period[0]
        window_end = self.total_step * self.cross_period[1]
        if not is_cross and place_in_unet == 'up':
            if self.init_store:
                self.step_store['first'].append(context.detach())
                self.step_store['previous'].append(context.detach())
            if self.update:
                # Snapshot the incoming context before it may be replaced.
                tmp = context.clone().detach()
            if self.restore and window_start <= self.cur_step <= window_end:
                context = torch.cat(
                    (self.step_store['first'][self.cur_index],
                     self.step_store['previous'][self.cur_index]),
                    dim=1).clone()
            if self.update:
                self.step_store['previous'][self.cur_index] = tmp
            self.cur_index += 1
        return context

    def update_x0(self, x0):
        """Record and/or modify the x0 tensor for the current step.

        With ``init_store`` set, ``x0`` and its channel statistics are
        appended to the store. With ``restorex0`` set, ``x0`` is
        re-normalised to the stored statistics inside the ada window and
        blended with the warped stored ``x0`` inside the warp window. With
        ``updatex0`` set, the unmodified input replaces the stored entry for
        the current step.

        Args:
            x0: A 4-D tensor (required by ``calc_mean_std`` when storing).

        Returns:
            The possibly modified ``x0``.
        """
        if self.init_store:
            self.step_store['x0_previous'].append(x0.detach())
            style_mean, style_std = calc_mean_std(x0.detach())
            # Mean and std are appended alternately, so the lookups below
            # use 2 * step for the mean and 2 * step + 1 for the std. This
            # assumes one call per step, in step order, while storing --
            # TODO confirm against the caller.
            self.step_store['first_ada'].append(style_mean.detach())
            self.step_store['first_ada'].append(style_std.detach())
        if self.updatex0:
            # Snapshot the input before it may be modified below.
            tmp = x0.clone().detach()
        if self.restorex0:
            step = self.cur_step
            ada_start = self.total_step * self.ada_period[0]
            ada_end = self.total_step * self.ada_period[1]
            if ada_start <= step <= ada_end:
                stored_std = self.step_store['first_ada'][2 * step + 1]
                stored_mean = self.step_store['first_ada'][2 * step]
                x0 = F.instance_norm(x0) * stored_std + stored_mean
            warp_start = self.total_step * self.warp_period[0]
            warp_end = self.total_step * self.warp_period[1]
            if warp_start <= step <= warp_end:
                pre = self.step_store['x0_previous'][step]
                # Masked blend: warped stored x0 where mask is 1, current
                # x0 where mask is 0.
                x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + (
                    1 - self.mask) * x0
        if self.updatex0:
            self.step_store['x0_previous'][self.cur_step] = tmp
        return x0

    def set_warp(self, flow, mask):
        """Store private copies of the flow and mask used by ``update_x0``."""
        self.flow = flow.clone()
        self.mask = mask.clone()

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Delegate to :meth:`forward`."""
        context = self.forward(context, is_cross, place_in_unet)
        return context

    def set_step(self, step):
        """Set the current step index."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total step count and reset the context index."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop all stored tensors and release their memory."""
        # Rebind first so the instance never lacks ``step_store``, then
        # collect garbage *before* emptying the CUDA cache: blocks freed by
        # the collector can only be returned by ``empty_cache`` if they are
        # already unoccupied when it runs.
        self.step_store = self.get_empty_store()
        gc.collect()
        torch.cuda.empty_cache()

    def set_task(self, task, restore_step=1.0):
        """Reset all task flags, then enable those named in ``task``.

        Args:
            task: Container or string tested with ``in`` for the keywords
                'initfirst', 'updatestyle', 'keepstyle', 'updatex0' and
                'keepx0'. 'initfirst' also clears the store.
            restore_step: Stored on the instance as ``restore_step``; it is
                not read anywhere in this class.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True