|
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_dalle import DallEConfig
|
|
|
|
|
class Conv2d(nn.Module): |
|
def __init__(self, n_in, n_out, kw, config, use_float16=True): |
|
super().__init__() |
|
|
|
assert n_in >= 1 |
|
assert n_out >= 1 |
|
assert kw >= 1 and kw % 2 == 1 |
|
|
|
self.n_in = n_in |
|
self.n_out = n_out |
|
self.kw = kw |
|
self.config = config |
|
self.use_float16 = use_float16 |
|
w = torch.empty( |
|
(n_out, n_in, kw, kw), |
|
dtype=torch.float32, |
|
device=config.device, |
|
requires_grad=config.requires_grad, |
|
) |
|
w.normal_(std=1 / math.sqrt(n_in * kw ** 2)) |
|
|
|
b = torch.zeros( |
|
(n_out,), |
|
dtype=torch.float32, |
|
device=config.device, |
|
requires_grad=config.requires_grad, |
|
) |
|
|
|
self.w = nn.Parameter(w) |
|
self.b = nn.Parameter(b) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
if self.use_float16 and 'cuda' in self.w.device.type: |
|
if x.dtype != torch.float16: |
|
x = x.half() |
|
w, b = self.w.half(), self.b.half() |
|
else: |
|
if x.dtype != torch.float32: |
|
x = x.float() |
|
w, b = self.w, self.b |
|
return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) |
|
|
|
def extra_repr(self): |
|
inner_repr = f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, " |
|
inner_repr += f"use_float16={self.use_float16}, " |
|
inner_repr += f"device={self.config.device}, " |
|
inner_repr += f"requires_grad={self.config.requires_grad}" |
|
return inner_repr |
|
|
|
|
|
class EncoderBlock(nn.Module):
    """Pre-activation residual block.

    The output is the identity path (a 1x1 projection when the channel count
    changes) plus a four-convolution bottleneck path scaled by
    1 / n_layers^2, which keeps activations bounded in deep stacks.
    """

    def __init__(self, n_in, n_out, n_layers, config):
        super().__init__()

        assert n_in >= 1
        assert n_out >= 1 and n_out % 4 == 0
        assert n_layers >= 1

        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4  # bottleneck width
        self.post_gain = 1 / (n_layers ** 2)

        # Project the shortcut only when the channel count changes.
        needs_projection = self.n_in != self.n_out
        self.id_path = (
            Conv2d(self.n_in, self.n_out, 1, config)
            if needs_projection
            else nn.Identity()
        )

        stages = [
            ('relu_1', nn.ReLU()),
            ('conv_1', Conv2d(self.n_in, self.n_hid, 3, config)),
            ('relu_2', nn.ReLU()),
            ('conv_2', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_3', nn.ReLU()),
            ('conv_3', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_4', nn.ReLU()),
            ('conv_4', Conv2d(self.n_hid, self.n_out, 1, config)),
        ]
        self.res_path = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
|
|
|
|
|
class DallEPreTrainedModel(PreTrainedModel):
    """Base class that wires DallEConfig into the transformers
    PreTrainedModel machinery (weight loading, saving, etc.)."""

    config_class = DallEConfig
    base_model_prefix = "dalle"
|
|
|
|
|
class DallEEncoder(DallEPreTrainedModel):
    """DALL-E dVAE encoder.

    Maps a float32 image batch of shape (B, input_channels, H, W) to logits
    over the visual vocabulary. Four residual groups double the channel
    count and (for the first three) halve the spatial size, followed by a
    1x1 output convolution to ``config.vocab_size`` channels.
    """

    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group

        in_channels = config.input_channels
        n_hid = config.n_hid

        # BUGFIX: forward() validates against self.input_channels, but the
        # original __init__ never assigned it, so every forward call raised
        # AttributeError. Record it here.
        self.input_channels = in_channels

        def make_group(n_in, n_out, pool):
            # One residual group: n_blk_per_group EncoderBlocks (only the
            # first block changes the channel count) plus an optional 2x2
            # max-pool for spatial downsampling.
            layers = [
                (f'block_{i + 1}',
                 EncoderBlock(n_in if i == 0 else n_out,
                              n_out, n_layers, config))
                for i in blk_range
            ]
            if pool:
                layers.append(('pool', nn.MaxPool2d(kernel_size=2)))
            return nn.Sequential(OrderedDict(layers))

        self.blocks = nn.Sequential(OrderedDict([
            ('input', Conv2d(in_channels, n_hid, 7, config)),
            ('group_1', make_group(n_hid, n_hid, pool=True)),
            ('group_2', make_group(n_hid, 2 * n_hid, pool=True)),
            ('group_3', make_group(2 * n_hid, 4 * n_hid, pool=True)),
            # The last group keeps the spatial size (no pool).
            ('group_4', make_group(4 * n_hid, 8 * n_hid, pool=False)),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                # Logits head runs in float32 regardless of use_float16.
                ('conv', Conv2d(
                    8 * n_hid, config.vocab_size,
                    1, config, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Validate the input and run it through the encoder stack.

        Raises:
            ValueError: if x is not 4-d, has the wrong channel count, or is
                not float32.
        """
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)