cerulianx committed
Commit 4e417e5
1 Parent(s): e902303

Upload utils.py

Files changed (1)
utils.py +59 -0
utils.py ADDED
@@ -0,0 +1,59 @@
+ import attr
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ logit_laplace_eps: float = 0.1
+
+ @attr.s(eq=False)
+ class Conv2d(nn.Module):
+     n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
+     n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
+     kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
+
+     use_float16: bool = attr.ib(default=True)
+     device: torch.device = attr.ib(default=torch.device('cpu'))
+     requires_grad: bool = attr.ib(default=False)
+
+     def __attrs_post_init__(self) -> None:
+         super().__init__()
+
+         w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
+                         device=self.device, requires_grad=self.requires_grad)
+         w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))
+
+         b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
+                         requires_grad=self.requires_grad)
+         self.w, self.b = nn.Parameter(w), nn.Parameter(b)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.use_float16 and 'cuda' in self.w.device.type:
+             if x.dtype != torch.float16:
+                 x = x.half()
+
+             w, b = self.w.half(), self.b.half()
+         else:
+             if x.dtype != torch.float32:
+                 x = x.float()
+
+             w, b = self.w, self.b
+
+         return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
+
+ def map_pixels(x: torch.Tensor) -> torch.Tensor:
+     if len(x.shape) != 4:
+         raise ValueError('expected input to be 4d')
+     if x.dtype != torch.float:
+         raise ValueError('expected input to have type float')
+
+     return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps
+
+ def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
+     if len(x.shape) != 4:
+         raise ValueError('expected input to be 4d')
+     if x.dtype != torch.float:
+         raise ValueError('expected input to have type float')
+
+     return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
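
For readers skimming the commit, a minimal smoke-test sketch follows. It assumes only the file above: Conv2d pads with (kw - 1) // 2 so odd kernels preserve spatial size, and map_pixels / unmap_pixels are affine inverses that keep pixel values inside [logit_laplace_eps, 1 - logit_laplace_eps] (presumably for a logit-Laplace likelihood, judging by the constant's name). The demo_roundtrip name, tensor shapes, and tolerance are illustrative assumptions, not part of the commit.

import torch

# Hypothetical smoke test for the helpers committed above; names, shapes,
# and tolerances are illustrative, not part of the committed file.
def demo_roundtrip() -> None:
    x = torch.rand(2, 3, 32, 32)  # batch of RGB images with values in [0, 1]

    # map_pixels applies x -> (1 - 2*eps)*x + eps, squeezing [0, 1] into
    # [eps, 1 - eps]; unmap_pixels inverts the affine map and clamps to [0, 1].
    y = map_pixels(x)
    assert torch.allclose(unmap_pixels(y), x, atol=1e-6)

    # Conv2d keeps H and W unchanged for odd kw via padding=(kw - 1) // 2;
    # on CPU it takes the float32 branch, since the use_float16 path only
    # activates when the weights live on a CUDA device.
    conv = Conv2d(n_in=3, n_out=8, kw=3)
    out = conv(y)
    assert out.shape == (2, 8, 32, 32)

if __name__ == '__main__':
    demo_roundtrip()

One design note visible in the code itself: forward casts the parameters with .half() on the fly rather than storing float16 weights, so the master copies in self.w and self.b stay float32 and the mixed-precision behavior is purely a forward-pass concern.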