Spaces:
Running
on
A10G
Running
on
A10G
File size: 4,753 Bytes
251e479 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gc
import torch
import torch.nn.functional as F
from flow.flow_utils import flow_warp
# AdaIN (adaptive instance normalization) statistics helper
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of *feat*.

    The first two dimensions of *feat* are treated as ``(N, C)``; the two
    remaining dimensions are flattened and reduced over.

    Args:
        feat: 4-D tensor.
        eps: Small value added to the variance to avoid divide-by-zero.

    Returns:
        Tuple ``(feat_mean, feat_std)``, each reshaped to ``(N, C, 1, 1)``.

    Raises:
        AssertionError: If *feat* does not have exactly four dimensions.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape (rather than view) so that non-contiguous inputs, such as
    # transposed tensors or strided slices, are accepted; for inputs whose
    # strides allow it this is still a view and no data is copied.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
class AttentionControl():
    """Record tensors during one pass and re-inject them during later passes.

    Two groups of tensors are kept in ``step_store``:

    * ``'first'`` and ``'previous'`` hold the contexts seen by ``forward``
      for non-cross calls made from the ``'up'`` part of the UNet, one entry
      per such call, addressed by the running counter ``cur_index``.
    * ``'x0_previous'`` holds the tensors passed to ``update_x0`` and
      ``'first_ada'`` holds their statistics as a flat list
      ``[mean_0, std_0, mean_1, std_1, ...]``; both are addressed through
      ``cur_step``.

    ``set_task`` selects which of recording, restoring and updating is
    active. Every ``*_period`` argument is a ``(start, end)`` pair that is
    multiplied by ``total_step`` to obtain an inclusive range of steps.
    """

    def __init__(self, inner_strength, mask_period, cross_period, ada_period,
                 warp_period):
        """Create a controller with every task flag switched off.

        Args:
            inner_strength: Stored on the instance; not read by this class.
            mask_period: Stored on the instance; not read by this class.
            cross_period: ``(start, end)`` fractions of ``total_step`` during
                which ``forward`` substitutes the stored contexts.
            ada_period: ``(start, end)`` fractions of ``total_step`` during
                which ``update_x0`` re-applies the stored mean and std.
            warp_period: ``(start, end)`` fractions of ``total_step`` during
                which ``update_x0`` blends in the warped stored tensor.
        """
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        self.cur_index = 0
        self.init_store = False
        self.restore = False
        self.update = False
        self.flow = None
        self.mask = None
        self.restorex0 = False
        self.updatex0 = False
        # Defined here so the attribute exists before the first set_task();
        # the value matches set_task's default.
        self.restore_step = 1.0
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period

    @staticmethod
    def get_empty_store():
        """Return a fresh store with one empty list per key."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': []
        }

    def _in_period(self, period):
        """Return True if ``cur_step`` lies in ``period`` scaled by total_step.

        Both ends of the range are inclusive.
        """
        start = self.total_step * period[0]
        end = self.total_step * period[1]
        return start <= self.cur_step <= end

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Record, replace and/or refresh the context of one attention call.

        Calls with ``is_cross`` true or ``place_in_unet`` other than ``'up'``
        return *context* unchanged and do not advance ``cur_index``.

        Args:
            context: Tensor handed to the attention layer.
            is_cross: Whether the call comes from a cross-attention layer.
            place_in_unet: Location of the calling layer; only ``'up'`` is
                acted upon.

        Returns:
            *context* itself, or, when restoring inside ``cross_period``, the
            stored ``'first'`` and ``'previous'`` entries concatenated along
            ``dim=1``.
        """
        if is_cross or place_in_unet != 'up':
            return context

        store = self.step_store
        if self.init_store:
            store['first'].append(context.detach())
            store['previous'].append(context.detach())
        if self.update:
            # Snapshot the incoming context before it may be replaced below.
            tmp = context.clone().detach()
        if self.restore and self._in_period(self.cross_period):
            # torch.cat already allocates a new tensor, so no extra clone.
            context = torch.cat(
                (store['first'][self.cur_index],
                 store['previous'][self.cur_index]),
                dim=1)
        if self.update:
            store['previous'][self.cur_index] = tmp
        self.cur_index += 1
        return context

    def update_x0(self, x0):
        """Record, re-normalize, warp-blend and/or refresh *x0*.

        Stored entries are looked up with ``cur_step``, which assumes that
        during recording this method ran once per step, starting at step 0.

        Args:
            x0: Tensor to process; it is passed to ``calc_mean_std`` when
                recording and to ``F.instance_norm`` when restoring.

        Returns:
            *x0*, possibly re-normalized with the stored statistics and
            blended with the warped stored tensor.
        """
        store = self.step_store
        if self.init_store:
            store['x0_previous'].append(x0.detach())
            style_mean, style_std = calc_mean_std(x0.detach())
            # Mean lands at the even index, std at the following odd index.
            store['first_ada'].append(style_mean.detach())
            store['first_ada'].append(style_std.detach())
        if self.updatex0:
            # Snapshot the unmodified input; it is what gets stored below.
            tmp = x0.clone().detach()
        if self.restorex0:
            if self._in_period(self.ada_period):
                mean = store['first_ada'][2 * self.cur_step]
                std = store['first_ada'][2 * self.cur_step + 1]
                x0 = F.instance_norm(x0) * std + mean
            if self._in_period(self.warp_period):
                pre = store['x0_previous'][self.cur_step]
                # Warped stored tensor where mask is 1, current x0 where 0.
                x0 = flow_warp(pre, self.flow, mode='nearest') * self.mask + (
                    1 - self.mask) * x0
        if self.updatex0:
            store['x0_previous'][self.cur_step] = tmp
        return x0

    def set_warp(self, flow, mask):
        """Store private copies of the flow and mask used by ``update_x0``."""
        self.flow = flow.clone()
        self.mask = mask.clone()

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Forward to :meth:`forward` with the same arguments."""
        context = self.forward(context, is_cross, place_in_unet)
        return context

    def set_step(self, step):
        """Set the current step used for period checks and x0 look-ups."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total number of steps and rewind ``cur_index`` to 0."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop every stored tensor and start again from an empty store."""
        del self.step_store
        # Collect garbage first so that tensors kept alive only by reference
        # cycles are freed before the CUDA cache is emptied; in the opposite
        # order their blocks would still be in use and stay allocated.
        gc.collect()
        torch.cuda.empty_cache()
        self.step_store = self.get_empty_store()

    def set_task(self, task, restore_step=1.0):
        """Reset every task flag, then enable those whose keyword is in *task*.

        Keywords are tested with ``in``: ``'initfirst'`` (record, and clear
        the store first), ``'updatestyle'``, ``'keepstyle'``, ``'updatex0'``
        and ``'keepx0'``. ``cur_index`` is rewound to 0.

        Args:
            task: Container or string that is searched for the keywords.
            restore_step: Stored on the instance; not read by this class.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True
|