Spaces:
Running
on
A10G
Running
on
A10G
import gc | |
import torch | |
import torch.nn.functional as F | |
from flow.flow_utils import flow_warp | |
# AdaIn | |
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of *feat*.

    Statistics are taken over every position after the first two dims, as
    used for AdaIN-style feature re-normalisation.

    Args:
        feat: 4-D tensor whose first two dims are treated as (N, C);
            presumably (N, C, H, W) -- confirm with callers.
        eps: Small value added to the variance before the square root so the
            returned std is never exactly zero.

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)``. ``std`` is
        ``sqrt(var + eps)`` where ``var`` is ``Tensor.var``'s default
        (unbiased) estimate, so an input with a single spatial position
        yields NaN.
    """
    size = feat.size()
    assert len(size) == 4, f'expected a 4-D tensor, got {len(size)} dims'
    N, C = size[:2]
    # reshape (not view) so non-contiguous inputs, e.g. after transpose or
    # permute, are accepted instead of raising.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
class AttentionControl():
    """Stateful hook that records and re-injects attention contexts and x0 tensors.

    Calling the object (see ``forward``) with an attention ``context`` can,
    for self-attention calls in the ``'up'`` part of the UNet, record the
    context, replace it with previously recorded contexts, and refresh the
    "previous" record. ``update_x0`` does analogous bookkeeping for an x0
    tensor and can apply AdaIN-style re-normalisation and a flow-warped
    blend with the stored previous x0. Which behaviours are active is chosen
    by ``set_task``.

    Period arguments are ``(start, end)`` pairs. Each bound is multiplied by
    ``total_step`` and compared against ``cur_step``, inclusive at both ends.
    The bounds are presumably fractions in [0, 1] -- TODO confirm with callers.

    Args:
        inner_strength: Stored as ``self.inner_strength``; not read by this
            class's own methods.
        mask_period: Stored as ``self.mask_period``; not read by this
            class's own methods.
        cross_period: Window in which ``forward`` substitutes stored contexts.
        ada_period: Window in which ``update_x0`` applies the stored
            mean/std re-normalisation.
        warp_period: Window in which ``update_x0`` blends in the warped
            previous x0.
    """

    def __init__(self, inner_strength, mask_period, cross_period, ada_period,
                 warp_period):
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        self.cur_index = 0
        # Behaviour flags; configured by set_task().
        self.init_store = False
        self.restore = False
        self.update = False
        self.restorex0 = False
        self.updatex0 = False
        # Previously only created by set_task(); initialise it here so the
        # attribute always exists.
        self.restore_step = 1.0
        # Flow and mask consumed by update_x0(); set via set_warp().
        self.flow = None
        self.mask = None
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period

    @staticmethod
    def get_empty_store():
        """Return a fresh store dict with an empty list for each record kind."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': []
        }

    def _in_period(self, period):
        """Return True if ``cur_step`` lies within ``period`` scaled by ``total_step`` (inclusive)."""
        return (self.total_step * period[0] <= self.cur_step <=
                self.total_step * period[1])

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Process one attention call and return the (possibly replaced) context.

        Only self-attention calls (``is_cross`` False) with
        ``place_in_unet == 'up'`` are affected. All other calls return
        ``context`` unchanged and do not advance ``cur_index``.

        For affected calls, in order:

        * ``init_store``: append the detached context to both ``'first'``
          and ``'previous'``.
        * ``restore`` and ``cur_step`` inside ``cross_period``: replace the
          context with the stored ``'first'`` and ``'previous'`` entries at
          ``cur_index``, concatenated along dim 1.
        * ``update``: overwrite ``'previous'[cur_index]`` with a copy of the
          incoming context, taken before any replacement.

        ``cur_index`` then advances by one.
        """
        if is_cross or place_in_unet != 'up':
            return context
        if self.init_store:
            self.step_store['first'].append(context.detach())
            self.step_store['previous'].append(context.detach())
        if self.update:
            # Snapshot the incoming context before it may be replaced below.
            tmp = context.clone().detach()
        if self.restore and self._in_period(self.cross_period):
            context = torch.cat(
                (self.step_store['first'][self.cur_index],
                 self.step_store['previous'][self.cur_index]),
                dim=1).clone()
        if self.update:
            self.step_store['previous'][self.cur_index] = tmp
        self.cur_index += 1
        return context

    def update_x0(self, x0):
        """Record and/or adjust an x0 tensor for the current step and return it.

        * ``init_store``: append the detached x0 to ``'x0_previous'``, and its
          per-channel mean then std (from ``calc_mean_std``) to
          ``'first_ada'``.
        * ``restorex0``:
          - Inside ``ada_period``, x0 is instance-normalised and rescaled with
            the stored (std, mean) pair for ``cur_step``.
          - Inside ``warp_period``, the stored previous x0 for ``cur_step`` is
            passed through ``flow_warp`` with ``self.flow``. The result is then
            blended as ``warped * mask + (1 - mask) * x0``.
        * ``updatex0``: overwrite ``'x0_previous'[cur_step]`` with a copy of
          the incoming x0, taken before any adjustment.

        Stored records are indexed by ``cur_step``. This assumes update_x0 was
        called once per step, in step order, while recording -- confirm with
        callers.
        """
        if self.init_store:
            self.step_store['x0_previous'].append(x0.detach())
            style_mean, style_std = calc_mean_std(x0.detach())
            self.step_store['first_ada'].append(style_mean.detach())
            self.step_store['first_ada'].append(style_std.detach())
        if self.updatex0:
            # Snapshot the incoming x0 before it may be modified below.
            tmp = x0.clone().detach()
        if self.restorex0:
            if self._in_period(self.ada_period):
                # 'first_ada' holds (mean, std) appended pairwise per call.
                mean = self.step_store['first_ada'][2 * self.cur_step]
                std = self.step_store['first_ada'][2 * self.cur_step + 1]
                x0 = F.instance_norm(x0) * std + mean
            if self._in_period(self.warp_period):
                pre = self.step_store['x0_previous'][self.cur_step]
                x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + (
                    1 - self.mask) * x0
        if self.updatex0:
            self.step_store['x0_previous'][self.cur_step] = tmp
        return x0

    def set_warp(self, flow, mask):
        """Store copies of the flow and blend mask used by ``update_x0``."""
        self.flow = flow.clone()
        self.mask = mask.clone()

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Alias for ``forward``."""
        context = self.forward(context, is_cross, place_in_unet)
        return context

    def set_step(self, step):
        """Set the current step index compared against the periods."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total step count used to scale periods; resets ``cur_index``."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop all stored records, release freed memory, and start a fresh store."""
        del self.step_store
        # Collect unreachable objects first so any tensors they hold are
        # returned to the allocator before its cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Configure the behaviour flags from substrings of ``task``.

        All flags are cleared and ``cur_index`` is reset, then:

        * ``'initfirst'`` enables ``init_store`` and clears the store.
        * ``'updatestyle'`` enables ``update``.
        * ``'keepstyle'`` enables ``restore``.
        * ``'updatex0'`` enables ``updatex0``.
        * ``'keepx0'`` enables ``restorex0``.

        ``restore_step`` is stored as ``self.restore_step``; it is not read by
        this class's own methods.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True