File size: 4,753 Bytes
251e479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gc

import torch
import torch.nn.functional as F

from flow.flow_utils import flow_warp

# AdaIN (adaptive instance normalization) helpers


def calc_mean_std(feat, eps=1e-5):
    """Return the per-(sample, channel) mean and standard deviation.

    Args:
        feat: 4-D tensor; the first two dimensions are kept and all
            remaining elements are flattened and reduced over.
        eps: Small constant added to the variance before the square root
            so that the resulting std is never zero.

    Returns:
        A ``(mean, std)`` tuple, each reshaped to ``(N, C, 1, 1)`` where
        ``N, C`` are the first two sizes of ``feat``.

    Raises:
        AssertionError: If ``feat`` does not have exactly four dimensions.
    """
    shape = feat.size()
    assert len(shape) == 4
    batch, channels = shape[:2]

    # Collapse everything after the channel axis into one reduction axis.
    flat = feat.view(batch, channels, -1)

    # Variance uses ``Tensor.var``'s default settings; eps goes in before
    # the sqrt, not after.
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std


class AttentionControl():
    """Cache attention contexts and x0 tensors and re-inject them later.

    Behaviour is driven by boolean flags that ``set_task`` derives from a
    task description:

    * ``init_store`` -- record incoming tensors into ``step_store``.
    * ``restore`` / ``update`` -- read / overwrite the stored
      ``'previous'`` attention contexts in ``forward``.
    * ``restorex0`` / ``updatex0`` -- the same for x0 tensors in
      ``update_x0``.

    Each ``*_period`` value is indexed with ``[0]`` and ``[1]``; both are
    multiplied by ``total_step`` to obtain the inclusive range of
    ``cur_step`` values in which the corresponding operation is active.
    """

    def __init__(self, inner_strength, mask_period, cross_period, ada_period,
                 warp_period):
        """Create a controller with empty stores and every task flag off.

        Args:
            inner_strength: Stored on the instance; not read by this class.
            mask_period: Stored on the instance; not read by this class.
            cross_period: ``(start, end)`` fractions of ``total_step``
                during which ``forward`` substitutes stored contexts.
            ada_period: ``(start, end)`` fractions of ``total_step`` during
                which ``update_x0`` re-normalises x0 with stored statistics.
            warp_period: ``(start, end)`` fractions of ``total_step``
                during which ``update_x0`` blends in the warped stored x0.
        """
        self.step_store = self.get_empty_store()
        self.cur_step = 0
        self.total_step = 0
        self.cur_index = 0
        self.init_store = False
        self.restore = False
        self.update = False
        self.flow = None
        self.mask = None
        self.restorex0 = False
        self.updatex0 = False
        # Same default as ``set_task`` so the attribute exists even before
        # the first task is selected.
        self.restore_step = 1.0
        self.inner_strength = inner_strength
        self.cross_period = cross_period
        self.mask_period = mask_period
        self.ada_period = ada_period
        self.warp_period = warp_period

    @staticmethod
    def get_empty_store():
        """Return a fresh store: one empty list per cached quantity."""
        return {
            'first': [],
            'previous': [],
            'x0_previous': [],
            'first_ada': []
        }

    def _within_period(self, period):
        """Return True if ``cur_step`` lies in ``period`` scaled by steps.

        ``period[0]`` and ``period[1]`` are multiplied by ``total_step``;
        both bounds are inclusive.
        """
        lower = self.total_step * period[0]
        upper = self.total_step * period[1]
        return lower <= self.cur_step <= upper

    def forward(self, context, is_cross: bool, place_in_unet: str):
        """Record, replace and/or refresh one attention context.

        Only calls with ``is_cross`` false and ``place_in_unet == 'up'`` are
        processed; every other call returns ``context`` untouched and does
        not advance ``cur_index``.

        Args:
            context: Attention context tensor for the current layer call.
            is_cross: Whether the call comes from a cross-attention layer.
            place_in_unet: Location tag of the layer; only ``'up'`` is
                handled.

        Returns:
            ``context`` itself, or -- when ``restore`` is set and the step is
            inside ``cross_period`` -- the stored ``'first'`` and
            ``'previous'`` entries for ``cur_index`` concatenated on dim 1.
        """
        if is_cross or place_in_unet != 'up':
            return context

        store = self.step_store
        if self.init_store:
            store['first'].append(context.detach())
            store['previous'].append(context.detach())

        # Snapshot the incoming context *before* it may be swapped out
        # below, so that 'previous' is refreshed with the real input.
        snapshot = context.clone().detach() if self.update else None

        if self.restore and self._within_period(self.cross_period):
            context = torch.cat(
                (store['first'][self.cur_index],
                 store['previous'][self.cur_index]),
                dim=1).clone()

        if self.update:
            store['previous'][self.cur_index] = snapshot

        # One slot per processed layer call; reset by ``set_task`` and
        # ``set_total_step``.
        self.cur_index += 1
        return context

    def update_x0(self, x0):
        """Record, re-normalise/warp and/or refresh the x0 tensor.

        Stored entries are addressed by ``cur_step`` (presumably one entry
        was recorded per step while ``init_store`` was set -- confirm
        against the caller).

        Args:
            x0: Tensor for the current step; passed to ``calc_mean_std``
                when recording, which requires it to be 4-D.

        Returns:
            ``x0`` unchanged, or the adjusted tensor when ``restorex0`` is
            set and the step falls in ``ada_period`` and/or ``warp_period``.
        """
        store = self.step_store
        if self.init_store:
            store['x0_previous'].append(x0.detach())
            style_mean, style_std = calc_mean_std(x0.detach())
            # Statistics are interleaved: mean at index 2*k, std at 2*k+1.
            store['first_ada'].append(style_mean.detach())
            store['first_ada'].append(style_std.detach())

        # Snapshot the raw input before any adjustment below.
        snapshot = x0.clone().detach() if self.updatex0 else None

        if self.restorex0:
            step = self.cur_step
            if self._within_period(self.ada_period):
                stats = store['first_ada']
                # Instance-normalise, then rescale with the stored std and
                # shift by the stored mean.
                x0 = F.instance_norm(x0) * stats[2 * step + 1] + \
                    stats[2 * step]
            if self._within_period(self.warp_period):
                pre = store['x0_previous'][step]
                # Blend: mask selects the warped stored tensor, (1 - mask)
                # keeps the current one. Requires ``set_warp`` to have
                # been called.
                warped = flow_warp(pre, self.flow, mode='nearest')
                x0 = warped * self.mask + (1 - self.mask) * x0

        if self.updatex0:
            store['x0_previous'][self.cur_step] = snapshot
        return x0

    def set_warp(self, flow, mask):
        """Store private copies of the flow and mask used by ``update_x0``."""
        self.flow = flow.clone()
        self.mask = mask.clone()

    def __call__(self, context, is_cross: bool, place_in_unet: str):
        """Delegate to ``forward``."""
        return self.forward(context, is_cross, place_in_unet)

    def set_step(self, step):
        """Set the current step index used for period checks and lookups."""
        self.cur_step = step

    def set_total_step(self, total_step):
        """Set the total number of steps and rewind the layer index."""
        self.total_step = total_step
        self.cur_index = 0

    def clear_store(self):
        """Drop every cached tensor and start with an empty store."""
        # Rebinding (rather than ``del`` followed by a later assignment)
        # keeps the attribute present at all times.
        self.step_store = self.get_empty_store()
        # Collect first so that anything the collector frees is already
        # released when the CUDA caching allocator is emptied.
        gc.collect()
        torch.cuda.empty_cache()

    def set_task(self, task, restore_step=1.0):
        """Select the active behaviours from a task description.

        All flags are reset and ``cur_index`` is rewound first. ``task`` is
        tested with ``in`` for each of ``'initfirst'``, ``'updatestyle'``,
        ``'keepstyle'``, ``'updatex0'`` and ``'keepx0'``. ``'initfirst'``
        additionally clears the store.

        Args:
            task: Container or string checked for the tokens above.
            restore_step: Stored as ``self.restore_step``; not read by this
                class.
        """
        self.init_store = False
        self.restore = False
        self.update = False
        self.cur_index = 0
        self.restore_step = restore_step
        self.updatex0 = False
        self.restorex0 = False
        if 'initfirst' in task:
            self.init_store = True
            self.clear_store()
        if 'updatestyle' in task:
            self.update = True
        if 'keepstyle' in task:
            self.restore = True
        if 'updatex0' in task:
            self.updatex0 = True
        if 'keepx0' in task:
            self.restorex0 = True