import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class WaveletTransform(nn.Module):

    def __init__(self, patch_size: int, inverse: bool = False):
        '''
        Multi-level 2D Haar wavelet transform (or its inverse), applied
        independently to non-overlapping `patch_size` x `patch_size` patches.

        `patchwise` in forward/invert makes *no difference*; the result is
        numerically identical either way. It is still enabled by default in
        case a non-square image is passed in, where the two paths may not be
        equivalent. `reshape` is pretty much useless.
        TODO: Clean up these options.
        '''
        super().__init__()
        assert patch_size >= 1 and (patch_size & (patch_size - 1)) == 0, \
            'patch_size must be a power of two'
        self.patch_size = patch_size
        self.inverse = inverse
        # Haar filter taps (1/sqrt(2), 1/sqrt(2)), from
        # https://github.com/NVIDIA/Cosmos-Tokenizer/blob/3584ae752ce8ebdbe06a420bf60d7513c0e878cc/cosmos_tokenizer/modules/patching.py#L33
        # Registered as non-persistent buffers so they follow the module
        # across .to()/.cuda() without entering the state_dict.
        self.register_buffer('haar', torch.tensor([0.7071067811865476, 0.7071067811865476]), persistent=False)
        self.register_buffer('arange', torch.arange(len(self.haar)), persistent=False)
        self.steps = int(math.log2(self.patch_size))

    def num_transformed_channels(self, in_channels: int = 3) -> int:
        '''
        Returns the number of channels to expect in the transformed image
        given the channels in the input image. Each DWT level maps every
        channel to four subbands (LL, LH, HL, HH), so `steps` levels
        multiply the channel count by 4**steps = patch_size**2.
        '''
        return in_channels * (4 ** self.steps)

    def forward(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        if self.inverse:
            return self.invert(x, patchwise=patchwise, from_reshaped=reshape)
        else:
            return self.transform(x, patchwise=patchwise, reshape=reshape)


    def transform(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        '''
        ### Parameters:
            `x`: ImageNet-normalized images with shape (B C H W);
                H and W must be divisible by `patch_size`
            `patchwise`: Whether to compute independently on patches
            `reshape`: Reshape the results to match the input HxW
        ### Returns:
            If `reshape`, returns (B C H W);
            otherwise, returns (B C*patch_size**2 H/patch_size W/patch_size)
        '''
        p = self.patch_size
        if patchwise:
            # Place patches into the batch dimension:
            # (B C H W) -> (B*LH*LW C P P), where LH = H/P and LW = W/P
            b, c, h, w = x.shape
            init_b = b
            # (B C H W) -> (B C LH LW P P)
            x = x.reshape(b, c, h//p, p, w//p, p).moveaxis(4, 3)
            # (B C LH LW P P) -> (B' C P P)
            x = x.moveaxis(1, 3).reshape(-1, c, p, p)

        # Each level halves the spatial extent and quadruples the channels;
        # after `steps` levels a PxP patch collapses to a single pixel.
        for _ in range(self.steps):
            x = self.dwt(x)

        if patchwise:
            # Extract patches from the batch dimension
            # (B' C' 1 1) -> (B LH LW C') -> (B C' LH LW)
            x = x.reshape(init_b, h//p, w//p, -1).moveaxis(3, 1)
        if reshape:
            # (B C*patch_size**2 H/patch_size W/patch_size) -> (B C H W)
            b, cp2, hdp, wdp = x.shape
            c, h, w = cp2//(p**2), hdp*p, wdp*p
            x = x.reshape(b, p, p, c, hdp, wdp)
            x = x.moveaxis(3, 1).moveaxis(3, 4).reshape(b, c, h, w).contiguous()
        return x

    def invert(self, x: torch.Tensor, patchwise: bool = True, from_reshaped: bool = False) -> torch.Tensor:
        '''
        ### Parameters:
            `x`: Wavelet-space input of shape (B C H W) when `from_reshaped=True`,
                otherwise (B C*patch_size**2 H/patch_size W/patch_size)
            `patchwise`: Whether to compute independently on patches
            `from_reshaped`: Determines the shape of `x`; should match the value
                of `reshape` used when calling `forward`
        '''
        p = self.patch_size
        if from_reshaped:
            # Undo the pixel-shuffle-style reshape:
            # (B C H W) -> (B C*patch_size**2 H/patch_size W/patch_size)
            b, c, h, w = x.shape
            cp2, hdp, wdp = c * p**2, h // p, w // p
            x = x.reshape(b, c, p, hdp, p, wdp)
            x = x.moveaxis(4, 3).moveaxis(1, 3).reshape(b, cp2, hdp, wdp)
        if patchwise:
            # Put patches into the batch dimension
            # (B C' LH LW) -> (B LH LW C') -> (B' C' 1 1)
            init_b, c_t, lh, lw = x.shape
            x = x.moveaxis(1, 3).reshape(-1, c_t, 1, 1)

        for _ in range(self.steps):
            x = self.idwt(x)

        if patchwise:
            # Extract patches from the batch dimension and expand
            # (B' C P P) -> (B C LH LW P P)
            x = x.reshape(init_b, lh, lw, *x.shape[1:]).moveaxis(3, 1)
            # (B C LH LW P P) -> (B C H W)
            x = x.moveaxis(3, 4).reshape(*x.shape[:2], lh*p, lw*p)
        return x


    def dwt(self, x: torch.Tensor):
        '''Single-level 2D Haar DWT: (B C H W) -> (B 4C H/2 W/2).'''
        dtype = x.dtype
        h = self.haar

        n = h.shape[0]
        g = x.shape[1]
        # Low-pass (hl) and high-pass (hh) analysis filters, repeated per
        # input channel for grouped (depthwise) convolution.
        hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(device=x.device, dtype=dtype)
        hl = hl.to(device=x.device, dtype=dtype)

        # Reflect-pad so the strided convolutions tile the input exactly.
        x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode='reflect').to(dtype)
        # Separable transform: filter along width, then along height.
        xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
        xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
        xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
        xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
        xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))

        # The 0.5 factor makes LL the local 2x2 mean, keeping coefficients
        # in the input's value range; it is undone by the 2.0 factor in idwt.
        return 0.5 * torch.cat([xll, xlh, xhl, xhh], dim=1)
    
    
    def idwt(self, x: torch.Tensor):
        '''Single-level 2D Haar IDWT: (B 4C H W) -> (B C 2H 2W).'''
        dtype = x.dtype
        h = self.haar
        n = h.shape[0]

        g = x.shape[1] // 4
        # Low-pass (hl) and high-pass (hh) synthesis filters, repeated per
        # output channel for grouped (depthwise) transposed convolution.
        hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = (h * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(g, 1, 1)
        hh = hh.to(device=x.device, dtype=dtype)
        hl = hl.to(device=x.device, dtype=dtype)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        # Inverse transform: upsample along height, then along width,
        # summing the low- and high-pass contributions at each stage
        # (the reverse of the order used in dwt).
        yl = F.conv_transpose2d(
            xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yl += F.conv_transpose2d(
            xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yh = F.conv_transpose2d(
            xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        yh += F.conv_transpose2d(
            xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)
        )
        y = F.conv_transpose2d(
            yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
        )
        y += F.conv_transpose2d(
            yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)
        )

        # Undo the 0.5 factor applied in dwt.
        return 2.0 * y
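

if __name__ == '__main__':
    # A minimal smoke-test sketch of the intended usage: it checks the shape
    # contract from `transform`'s docstring, the forward/invert round trip,
    # and the docstring's claim that `patchwise` is numerically irrelevant
    # for square inputs. The sizes below (patch_size=8, 64x64 input) are
    # illustrative assumptions, not values prescribed anywhere in this file.
    torch.manual_seed(0)
    p = 8
    fwd = WaveletTransform(patch_size=p)
    inv = WaveletTransform(patch_size=p, inverse=True)

    x = torch.randn(2, 3, 64, 64)

    # Shape contract: (B C H W) -> (B C*p**2 H/p W/p)
    y = fwd(x)
    assert y.shape == (2, fwd.num_transformed_channels(3), 64 // p, 64 // p)

    # The round trip should reconstruct the input up to float error.
    assert torch.allclose(inv(y), x, atol=1e-5)

    # `patchwise` should make no numerical difference on square inputs.
    assert torch.allclose(y, fwd(x, patchwise=False), atol=1e-5)

    # `reshape=True` keeps the input HxW; invert with `from_reshaped=True`.
    y_r = fwd(x, reshape=True)
    assert y_r.shape == x.shape
    assert torch.allclose(inv(y_r, from_reshaped=True), x, atol=1e-5)

    print('WaveletTransform round-trip checks passed.')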