52Hz commited on
Commit
9c28092
1 Parent(s): e3541a2

Create transform.py

Browse files
Files changed (1) hide show
  1. WT/transform.py +53 -0
WT/transform.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ def dwt_init(x):
5
+ x01 = x[:, :, 0::2, :] / 2
6
+ x02 = x[:, :, 1::2, :] / 2
7
+ x1 = x01[:, :, :, 0::2]
8
+ x2 = x02[:, :, :, 0::2]
9
+ x3 = x01[:, :, :, 1::2]
10
+ x4 = x02[:, :, :, 1::2]
11
+ x_LL = x1 + x2 + x3 + x4
12
+ x_HL = -x1 - x2 + x3 + x4
13
+ x_LH = -x1 + x2 - x3 + x4
14
+ x_HH = x1 - x2 - x3 + x4
15
+ # print(x_HH[:, 0, :, :])
16
+ return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
17
+
18
+ def iwt_init(x):
19
+ r = 2
20
+ in_batch, in_channel, in_height, in_width = x.size()
21
+ out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r ** 2)), r * in_height, r * in_width
22
+ x1 = x[:, 0:out_channel, :, :] / 2
23
+ x2 = x[:, out_channel:out_channel * 2, :, :] / 2
24
+ x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
25
+ x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
26
+ h = torch.zeros([out_batch, out_channel, out_height, out_width]).cuda() #
27
+
28
+ h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
29
+ h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
30
+ h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
31
+ h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
32
+
33
+ return h
34
+
35
+
36
+ class DWT(nn.Module):
37
+ def __init__(self):
38
+ super(DWT, self).__init__()
39
+ self.requires_grad = True
40
+
41
+ def forward(self, x):
42
+ return dwt_init(x)
43
+
44
+
45
+ class IWT(nn.Module):
46
+ def __init__(self):
47
+ super(IWT, self).__init__()
48
+ self.requires_grad = True
49
+
50
+ def forward(self, x):
51
+ return iwt_init(x)
52
+
53
+