Rerender / src /controller.py
Anonymous-sub's picture
merge (#1)
251e479
raw history blame
No virus
4.75 kB
import gc
import torch
import torch.nn.functional as F
from flow.flow_utils import flow_warp
# AdaIn
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    Statistics are taken over every element after the first two dimensions
    of a 4-D tensor (used by the AdaIN step elsewhere in this file).

    Args:
        feat: 4-D tensor; the first two dimensions are read as ``N, C``.
        eps: small value added to the variance to avoid divide-by-zero.

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)``.
    """
    size = feat.size()
    assert len(size) == 4, 'calc_mean_std expects a 4-D tensor'
    N, C = size[:2]
    # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
    # are accepted too; for contiguous inputs it is still a zero-copy view.
    flat = feat.reshape(N, C, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
class AttentionControl():
    """Stateful hook that caches and re-injects attention context and x0.

    The instance keeps four lists in ``step_store``:

    * ``'first'`` / ``'previous'``: contexts captured by ``forward``.
    * ``'x0_previous'``: x0 tensors captured by ``update_x0``.
    * ``'first_ada'``: interleaved ``[mean, std, mean, std, ...]`` statistics
      of x0, two entries per ``update_x0`` call made while ``init_store``
      is set.

    Which of the store / restore / update behaviours is active is selected
    with ``set_task``. Each ``*_period`` argument is a ``(start, end)`` pair
    of fractions that is multiplied by ``total_step`` and compared
    (inclusively) against ``cur_step``.

    NOTE(review): ``inner_strength`` and ``mask_period`` are only stored
    here; presumably they are read by callers outside this class.
    """

    def __init__(self, inner_strength, mask_period, cross_period, ada_period,
                 warp_period):
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        # Position inside step_store['first'/'previous'] for the next
        # matching forward() call; reset by set_total_step / set_task.
        self.cur_index = 0
        self.init_store = False
        self.restore = False
        self.update = False
        self.flow = None
        self.mask = None
        self.restorex0 = False
        self.updatex0 = False
        # Defined up front (matching set_task's default) so the attribute
        # exists even before set_task has been called.
        self.restore_step = 1.0
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period

    @staticmethod
    def get_empty_store():
        """Return a fresh store dict with one empty list per key."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': []
        }

    def _in_period(self, period):
        """Tell whether cur_step lies in period scaled by total_step."""
        start, end = period[0], period[1]
        return (self.cur_step >= self.total_step * start
                and self.cur_step <= self.total_step * end)

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Store, replace and/or refresh the attention context.

        Only calls with ``is_cross`` false and ``place_in_unet == 'up'`` are
        handled; every other call returns ``context`` untouched and does not
        advance ``cur_index``.

        Returns:
            The incoming ``context``, or, when restoring inside
            ``cross_period``, the stored 'first' and 'previous' entries for
            ``cur_index`` concatenated along ``dim=1``.
        """
        if is_cross or place_in_unet != 'up':
            return context

        store = self.step_store
        if self.init_store:
            detached = context.detach()
            store['first'].append(detached)
            store['previous'].append(detached)

        # Snapshot the incoming context before it may be replaced below.
        incoming = context.clone().detach() if self.update else None

        if self.restore and self._in_period(self.cross_period):
            context = torch.cat(
                (store['first'][self.cur_index],
                 store['previous'][self.cur_index]),
                dim=1).clone()

        if self.update:
            store['previous'][self.cur_index] = incoming

        self.cur_index += 1
        return context

    def update_x0(self, x0):
        """Store, restyle and/or warp-blend ``x0`` for the current step.

        * ``init_store``: append ``x0`` and its per-channel mean/std.
        * ``restorex0`` inside ``ada_period``: instance-normalise ``x0``,
          then scale by the stored std and shift by the stored mean.
        * ``restorex0`` inside ``warp_period``: blend the flow-warped stored
          x0 with the current one using ``self.mask``.
        * ``updatex0``: overwrite the stored x0 of this step with the
          incoming (unmodified) ``x0``.

        NOTE(review): indexing by ``cur_step`` assumes the store was filled
        with exactly one ``update_x0`` call per step, starting at step 0 -
        confirm against the caller.
        """
        store = self.step_store
        if self.init_store:
            store['x0_previous'].append(x0.detach())
            style_mean, style_std = calc_mean_std(x0.detach())
            # Interleaved layout: index 2*k is the mean, 2*k + 1 the std.
            store['first_ada'].append(style_mean.detach())
            store['first_ada'].append(style_std.detach())

        # Snapshot before x0 may be modified below.
        incoming = x0.clone().detach() if self.updatex0 else None

        if self.restorex0:
            if self._in_period(self.ada_period):
                mean = store['first_ada'][2 * self.cur_step]
                std = store['first_ada'][2 * self.cur_step + 1]
                x0 = F.instance_norm(x0) * std + mean
            if self._in_period(self.warp_period):
                # set_warp() must have provided flow and mask beforehand.
                pre = store['x0_previous'][self.cur_step]
                x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + (
                    1 - self.mask) * x0

        if self.updatex0:
            store['x0_previous'][self.cur_step] = incoming
        return x0

    def set_warp(self, flow, mask):
        """Keep private copies of the flow and mask used by update_x0."""
        self.flow = flow.clone()
        self.mask = mask.clone()

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Delegate to forward()."""
        context = self.forward(context, is_cross, place_in_unet)
        return context

    def set_step(self, step):
        """Set the current step index."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total number of steps and rewind cur_index."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop every stored tensor and release memory that is now unused."""
        del self.step_store
        # Collect first so anything freed by the collector can also be
        # released by the CUDA caching allocator in the same call.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Select the active behaviours from the keywords found in ``task``.

        Recognised keywords (tested with ``in``): ``'initfirst'`` (also
        clears the store), ``'updatestyle'``, ``'keepstyle'``,
        ``'updatex0'`` and ``'keepx0'``. Every flag not named is switched
        off and ``cur_index`` is rewound.
        """
        self.cur_index = 0
        self.restore_step = restore_step
        self.init_store = 'initfirst' in task
        self.update = 'updatestyle' in task
        self.restore = 'keepstyle' in task
        self.updatex0 = 'updatex0' in task
        self.restorex0 = 'keepx0' in task
        if self.init_store:
            self.clear_store()