dorkai committed on
Commit
ba80407
1 Parent(s): 5cccec8

Upload 4 files

Browse files
Files changed (4) hide show
  1. __init__.py +18 -0
  2. decoder.py +94 -0
  3. encoder.py +93 -0
  4. utils.py +59 -0
__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io, requests
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from dall_e.encoder import Encoder
6
+ from dall_e.decoder import Decoder
7
+ from dall_e.utils import map_pixels, unmap_pixels
8
+
9
def load_model(path: str, device: torch.device = None) -> nn.Module:
    """Load a serialized model from a local file or an HTTP(S) URL.

    Args:
        path: Filesystem path, or a URL beginning with ``http://`` or
            ``https://`` from which the payload is downloaded.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object that ``torch.load`` deserializes from the payload.

    Raises:
        requests.HTTPError: The server replied with an error status.
        requests.Timeout: The server did not respond within the timeout.
        OSError: The local file could not be opened.
    """
    # SECURITY: torch.load unpickles its input, which can execute arbitrary
    # code. Only call this with checkpoints from a source you trust.
    if path.startswith(('http://', 'https://')):
        # (connect, read) timeout so a stalled server cannot hang the caller
        # indefinitely; the read timeout is per socket read, not per download.
        resp = requests.get(path, timeout=(10, 60))
        resp.raise_for_status()

        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)

    with open(path, 'rb') as f:
        return torch.load(f, map_location=device)
decoder.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import attr
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from collections import OrderedDict
9
+ from functools import partial
10
+ from dall_e.utils import Conv2d
11
+
12
@attr.s(eq=False, repr=False)
class DecoderBlock(nn.Module):
    """Residual block: an identity/1x1 shortcut plus a scaled four-conv branch.

    The residual branch is ReLU -> conv, four times, with kernel widths
    1, 3, 3, 3 and a hidden width of ``n_out // 4``. Its output is scaled by
    ``1 / n_layers ** 2`` before being added to the shortcut.
    """

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these lambdas only fail when the comparison itself raises -- confirm
    # whether out-of-range values were meant to be rejected.
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Build the shortcut and residual paths once attrs has set the fields."""
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)

        # The shortcut needs a 1x1 projection only when the width changes.
        if self.n_in != self.n_out:
            self.id_path = conv(self.n_in, self.n_out, 1)
        else:
            self.id_path = nn.Identity()

        # (in_channels, out_channels, kernel_width) for conv_1 .. conv_4.
        conv_specs = (
            (self.n_in, self.n_hid, 1),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_out, 3),
        )
        layers = OrderedDict()
        for idx, (c_in, c_out, width) in enumerate(conv_specs, start=1):
            layers[f'relu_{idx}'] = nn.ReLU()
            layers[f'conv_{idx}'] = conv(c_in, c_out, width)
        self.res_path = nn.Sequential(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return shortcut(x) plus the gain-scaled residual branch output."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
40
+
41
@attr.s(eq=False, repr=False)
class Decoder(nn.Module):
    """Convolutional decoder built from four groups of ``DecoderBlock``s.

    Layout of ``self.blocks``: an ``input`` 1x1 conv (float32 only), then
    ``group_1`` .. ``group_4`` with widths 8x, 4x, 2x and 1x ``n_hid`` (the
    first three end with a 2x nearest-neighbour upsample), then an ``output``
    ReLU + 1x1 conv producing ``2 * output_channels`` channels.
    """

    group_count: int = 4

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these lambdas only fail when the comparison itself raises -- confirm
    # whether out-of-range values were meant to be rejected.
    n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8)
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        """Assemble ``self.blocks`` once attrs has set the fields."""
        super().__init__()

        total_layers = self.group_count * self.n_blk_per_group
        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        block = partial(DecoderBlock, n_layers=total_layers, device=self.device,
                        requires_grad=self.requires_grad)

        stages = OrderedDict()
        stages['input'] = conv(self.vocab_size, self.n_init, 1, use_float16=False)

        # Width multipliers of n_hid for group_1 .. group_4.
        multipliers = (8, 4, 2, 1)
        width_in = self.n_init
        for group_idx, mult in enumerate(multipliers, start=1):
            width_out = mult * self.n_hid
            members = OrderedDict()
            for i in range(self.n_blk_per_group):
                # Only the first block of a group changes the channel count.
                members[f'block_{i + 1}'] = block(width_in if i == 0 else width_out, width_out)
            # Every group except the last doubles the spatial resolution.
            if group_idx < len(multipliers):
                members['upsample'] = nn.Upsample(scale_factor=2, mode='nearest')
            stages[f'group_{group_idx}'] = nn.Sequential(members)
            width_in = width_out

        head = OrderedDict()
        head['relu'] = nn.ReLU()
        head['conv'] = conv(1 * self.n_hid, 2 * self.output_channels, 1)
        stages['output'] = nn.Sequential(head)

        self.blocks = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Validate ``x`` and run it through ``self.blocks``.

        Raises:
            ValueError: ``x`` is not 4-d, its second dimension differs from
                ``vocab_size``, or its dtype is not ``torch.float32``.
        """
        shape = x.shape
        if len(shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')

        channels = shape[1]
        if channels != self.vocab_size:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')

        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
encoder.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import attr
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from collections import OrderedDict
9
+ from functools import partial
10
+ from dall_e.utils import Conv2d
11
+
12
@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    """Residual block: an identity/1x1 shortcut plus a scaled four-conv branch.

    The residual branch is ReLU -> conv, four times, with kernel widths
    3, 3, 3, 1 and a hidden width of ``n_out // 4``. Its output is scaled by
    ``1 / n_layers ** 2`` before being added to the shortcut.
    """

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these lambdas only fail when the comparison itself raises -- confirm
    # whether out-of-range values were meant to be rejected.
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Build the shortcut and residual paths once attrs has set the fields."""
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)

        # The shortcut needs a 1x1 projection only when the width changes.
        if self.n_in != self.n_out:
            self.id_path = conv(self.n_in, self.n_out, 1)
        else:
            self.id_path = nn.Identity()

        # (in_channels, out_channels, kernel_width) for conv_1 .. conv_4.
        conv_specs = (
            (self.n_in, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_hid, 3),
            (self.n_hid, self.n_out, 1),
        )
        layers = OrderedDict()
        for idx, (c_in, c_out, width) in enumerate(conv_specs, start=1):
            layers[f'relu_{idx}'] = nn.ReLU()
            layers[f'conv_{idx}'] = conv(c_in, c_out, width)
        self.res_path = nn.Sequential(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return shortcut(x) plus the gain-scaled residual branch output."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
40
+
41
@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    """Convolutional encoder built from four groups of ``EncoderBlock``s.

    Layout of ``self.blocks``: an ``input`` 7x7 conv, then ``group_1`` ..
    ``group_4`` with widths 1x, 2x, 4x and 8x ``n_hid`` (the first three end
    with a 2x2 max-pool), then an ``output`` ReLU + 1x1 conv (float32 only)
    producing ``vocab_size`` channels.
    """

    group_count: int = 4

    # NOTE(review): attrs does not inspect a validator's return value, so
    # these lambdas only fail when the comparison itself raises -- confirm
    # whether out-of-range values were meant to be rejected.
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        """Assemble ``self.blocks`` once attrs has set the fields."""
        super().__init__()

        total_layers = self.group_count * self.n_blk_per_group
        conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        block = partial(EncoderBlock, n_layers=total_layers, device=self.device,
                        requires_grad=self.requires_grad)

        stages = OrderedDict()
        stages['input'] = conv(self.input_channels, 1 * self.n_hid, 7)

        # Width multipliers of n_hid for group_1 .. group_4.
        multipliers = (1, 2, 4, 8)
        width_in = 1 * self.n_hid
        for group_idx, mult in enumerate(multipliers, start=1):
            width_out = mult * self.n_hid
            members = OrderedDict()
            for i in range(self.n_blk_per_group):
                # Only the first block of a group changes the channel count.
                members[f'block_{i + 1}'] = block(width_in if i == 0 else width_out, width_out)
            # Every group except the last halves the spatial resolution.
            if group_idx < len(multipliers):
                members['pool'] = nn.MaxPool2d(kernel_size=2)
            stages[f'group_{group_idx}'] = nn.Sequential(members)
            width_in = width_out

        head = OrderedDict()
        head['relu'] = nn.ReLU()
        head['conv'] = conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)
        stages['output'] = nn.Sequential(head)

        self.blocks = nn.Sequential(stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Validate ``x`` and run it through ``self.blocks``.

        Raises:
            ValueError: ``x`` is not 4-d, its second dimension differs from
                ``input_channels``, or its dtype is not ``torch.float32``.
        """
        shape = x.shape
        if len(shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')

        channels = shape[1]
        if channels != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')

        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import attr
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ logit_laplace_eps: float = 0.1
9
+
10
def _require(predicate, requirement):
    """Return an attrs validator raising ``ValueError`` unless ``predicate(value)`` holds."""

    def _validator(instance, attribute, value):
        # attrs ignores a validator's return value, so rejection must raise.
        if not predicate(value):
            raise ValueError(f'{attribute.name} must be {requirement}, got {value!r}')

    return _validator


@attr.s(eq=False)
class Conv2d(nn.Module):
    """Square-kernel 2-D convolution with padding ``(kw - 1) // 2``.

    Weights are stored in float32 as ``self.w`` (shape
    ``(n_out, n_in, kw, kw)``) and ``self.b`` (shape ``(n_out,)``). When
    ``use_float16`` is set and the weights live on a CUDA device, the forward
    pass casts input, weight and bias to float16; otherwise it runs in float32.

    Raises:
        ValueError: ``n_in`` or ``n_out`` is below 1, or ``kw`` is not a
            positive odd integer.
    """

    n_in: int = attr.ib(validator=_require(lambda x: x >= 1, '>= 1'))
    n_out: int = attr.ib(validator=_require(lambda x: x >= 1, '>= 1'))
    kw: int = attr.ib(validator=_require(lambda x: x >= 1 and x % 2 == 1,
                                         'a positive odd integer'))

    use_float16: bool = attr.ib(default=True)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        """Allocate and initialise the weight and bias parameters."""
        super().__init__()

        # Initialise plain tensors first: an in-place normal_() on a leaf
        # tensor that already requires grad raises a RuntimeError.
        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
                        device=self.device)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)

        # nn.Parameter defaults to requires_grad=True, so the flag has to be
        # passed explicitly for requires_grad=False to take effect.
        self.w = nn.Parameter(w, requires_grad=self.requires_grad)
        self.b = nn.Parameter(b, requires_grad=self.requires_grad)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the stored weights, casting dtypes as needed."""
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
44
+
45
def map_pixels(x: torch.Tensor) -> torch.Tensor:
    """Apply the affine map ``(1 - 2 * eps) * x + eps`` with ``eps = logit_laplace_eps``.

    Raises:
        ValueError: ``x`` is not 4-d or its dtype is not ``torch.float``.
    """
    rank = len(x.shape)
    if rank != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    scale = 1 - 2 * logit_laplace_eps
    return scale * x + logit_laplace_eps
52
+
53
def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    """Compute ``(x - eps) / (1 - 2 * eps)`` clamped to ``[0, 1]``, with ``eps = logit_laplace_eps``.

    Raises:
        ValueError: ``x`` is not 4-d or its dtype is not ``torch.float``.
    """
    rank = len(x.shape)
    if rank != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    scale = 1 - 2 * logit_laplace_eps
    rescaled = (x - logit_laplace_eps) / scale
    return torch.clamp(rescaled, 0, 1)