File size: 5,499 Bytes
eb21b46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_dalle import DallEConfig


class Conv2d(nn.Module):
    def __init__(self, n_in, n_out, kw, config, use_float16=True):
        super().__init__()
        
        assert n_in >= 1
        assert n_out >= 1
        assert kw >= 1 and kw % 2 == 1
        
        self.n_in = n_in
        self.n_out = n_out
        self.kw = kw
        self.config = config
        self.use_float16 = use_float16
        w = torch.empty(
            (n_out, n_in, kw, kw),
            dtype=torch.float32,
            device=config.device,
            requires_grad=config.requires_grad,
        )
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))
        
        b = torch.zeros(
            (n_out,),
            dtype=torch.float32,
            device=config.device,
            requires_grad=config.requires_grad,
        )
        
        self.w = nn.Parameter(w)
        self.b = nn.Parameter(b)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
    
    def extra_repr(self):
        inner_repr = f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, "
        inner_repr += f"use_float16={self.use_float16}, "
        inner_repr += f"device={self.config.device}, "
        inner_repr += f"requires_grad={self.config.requires_grad}"
        return inner_repr
    
        
class EncoderBlock(nn.Module):
    """Pre-activation residual bottleneck block for the encoder.

    The residual branch is four ReLU->conv stages (3x3, 3x3, 3x3, 1x1)
    through a hidden width of ``n_out // 4``; its output is scaled by
    ``1 / n_layers**2`` before being added to the identity branch. When
    the channel count changes, the identity branch projects with a 1x1
    convolution; otherwise it passes the input through unchanged.
    """

    def __init__(self, n_in, n_out, n_layers, config):
        super().__init__()

        assert n_in >= 1
        assert n_out >= 1 and n_out % 4 == 0
        assert n_layers >= 1

        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4
        # Down-weight the residual branch so deep stacks stay stable.
        self.post_gain = 1 / (n_layers ** 2)

        # Identity branch: only project when the widths differ.
        if self.n_in == self.n_out:
            self.id_path = nn.Identity()
        else:
            self.id_path = Conv2d(self.n_in, self.n_out, 1, config)

        stages = [
            ('relu_1', nn.ReLU()),
            ('conv_1', Conv2d(self.n_in, self.n_hid, 3, config)),
            ('relu_2', nn.ReLU()),
            ('conv_2', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_3', nn.ReLU()),
            ('conv_3', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_4', nn.ReLU()),
            ('conv_4', Conv2d(self.n_hid, self.n_out, 1, config)),
        ]
        self.res_path = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)


class DallEPreTrainedModel(PreTrainedModel):
    """Base class wiring the DALL-E configuration into the
    ``transformers`` pretrained-model machinery."""

    config_class = DallEConfig
    base_model_prefix = "dalle"
    
    
class DallEEncoder(DallEPreTrainedModel):
    """Convolutional encoder mapping images to vocab-size logit maps.

    Four groups of residual blocks widen the representation from
    ``n_hid`` to ``8 * n_hid`` channels; the first three groups each end
    with a 2x2 max-pool (dividing spatial resolution by 8 overall), and a
    final ReLU + 1x1 convolution projects to ``config.vocab_size``
    channels in float32.
    """

    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group

        in_channels = config.input_channels
        n_hid = config.n_hid

        # BUG FIX: forward() validates against self.input_channels, but it
        # was never assigned (only the local variable existed), so every
        # forward call raised AttributeError. Store it on the instance.
        self.input_channels = in_channels

        self.blocks = nn.Sequential(OrderedDict([
            ('input', Conv2d(in_channels, n_hid, 7, config)),
            # Groups 2-4 widen on their first block only; later blocks in
            # a group keep the group's output width.
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(n_hid, n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       n_hid if i == 0 else 2 * n_hid,
                       2 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       2 * n_hid if i == 0 else 4 * n_hid,
                       4 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            # The last group has no pooling stage.
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}',
                   EncoderBlock(
                       4 * n_hid if i == 0 else 8 * n_hid,
                       8 * n_hid, n_layers, config))
                  for i in blk_range],
            ]))),
            # Project to logits; keep float32 for a numerically stable head.
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', Conv2d(
                    8 * n_hid, config.vocab_size,
                    1, config, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Validate the input and run it through the encoder stack.

        Args:
            x: float32 tensor of shape ``(batch, input_channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, vocab_size, H // 8, W // 8)``.

        Raises:
            ValueError: if ``x`` is not 4d, has the wrong channel count,
                or is not float32.
        """
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)