File size: 5,499 Bytes
eb21b46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_dalle import DallEConfig
class Conv2d(nn.Module):
    """2-d convolution with explicit weight/bias parameters and 'same' padding.

    Weights are drawn from N(0, 1 / (n_in * kw**2)); the bias starts at zero.
    On CUDA devices the forward pass optionally casts input and parameters to
    float16 (``use_float16``); on CPU it always runs in float32.

    Args:
        n_in: number of input channels (>= 1).
        n_out: number of output channels (>= 1).
        kw: square kernel size; must be odd so the output keeps the spatial size.
        config: model config providing ``device`` and ``requires_grad``.
        use_float16: run the conv in half precision when the weights live on CUDA.
    """

    def __init__(self, n_in, n_out, kw, config, use_float16=True):
        super().__init__()
        assert n_in >= 1
        assert n_out >= 1
        assert kw >= 1 and kw % 2 == 1

        self.n_in = n_in
        self.n_out = n_out
        self.kw = kw
        self.config = config
        self.use_float16 = use_float16

        # Build the tensors WITHOUT requires_grad: the in-place init below
        # would raise "a leaf Variable that requires grad is being used in an
        # in-place operation" whenever config.requires_grad is True.
        w = torch.empty(
            (n_out, n_in, kw, kw),
            dtype=torch.float32,
            device=config.device,
        )
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))
        b = torch.zeros(
            (n_out,),
            dtype=torch.float32,
            device=config.device,
        )

        # Fix: pass requires_grad explicitly — nn.Parameter defaults to
        # requires_grad=True and ignores the flag on the wrapped tensor, so
        # config.requires_grad=False previously had no effect.
        self.w = nn.Parameter(w, requires_grad=config.requires_grad)
        self.b = nn.Parameter(b, requires_grad=config.requires_grad)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution; padding keeps the spatial dimensions."""
        if self.use_float16 and 'cuda' in self.w.device.type:
            # Half-precision path only makes sense on CUDA hardware.
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

    def extra_repr(self):
        """Summarize construction arguments for ``repr`` / ``print`` output."""
        inner_repr = f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, "
        inner_repr += f"use_float16={self.use_float16}, "
        inner_repr += f"device={self.config.device}, "
        inner_repr += f"requires_grad={self.config.requires_grad}"
        return inner_repr
class EncoderBlock(nn.Module):
    """Residual block: identity/1x1 shortcut plus a scaled 4-conv bottleneck path.

    The residual branch squeezes to ``n_out // 4`` channels through three 3x3
    convolutions, then expands back to ``n_out`` with a 1x1 convolution. Its
    output is scaled by ``1 / n_layers**2`` before being added to the shortcut.
    """

    def __init__(self, n_in, n_out, n_layers, config):
        super().__init__()
        assert n_in >= 1
        assert n_out >= 1 and n_out % 4 == 0
        assert n_layers >= 1

        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4
        # Down-weight the residual branch so very deep stacks stay stable.
        self.post_gain = 1 / (n_layers ** 2)

        # A 1x1 conv adapts the shortcut only when the channel count changes.
        self.id_path = (
            Conv2d(n_in, n_out, 1, config) if n_in != n_out else nn.Identity()
        )

        hidden = self.n_hid
        stages = [
            ('relu_1', nn.ReLU()),
            ('conv_1', Conv2d(n_in, hidden, 3, config)),
            ('relu_2', nn.ReLU()),
            ('conv_2', Conv2d(hidden, hidden, 3, config)),
            ('relu_3', nn.ReLU()),
            ('conv_3', Conv2d(hidden, hidden, 3, config)),
            ('relu_4', nn.ReLU()),
            ('conv_4', Conv2d(hidden, n_out, 1, config)),
        ]
        self.res_path = nn.Sequential(OrderedDict(stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``shortcut(x) + post_gain * residual(x)``."""
        shortcut = self.id_path(x)
        residual = self.res_path(x)
        return shortcut + self.post_gain * residual
class DallEPreTrainedModel(PreTrainedModel):
    """Base class wiring DALL-E modules into the transformers save/load machinery."""

    config_class = DallEConfig
    base_model_prefix = "dalle"
class DallEEncoder(DallEPreTrainedModel):
    """DALL-E image encoder: a 7x7 stem, four residual groups (each followed by
    2x2 max-pooling except the last), and a 1x1 output conv producing
    ``config.vocab_size`` logit channels.

    Expects a float32 NCHW tensor with ``config.input_channels`` channels;
    spatial size is reduced by 8x overall (three pooling stages).
    """

    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group
        n_hid = config.n_hid

        # Fix: forward() validates against self.input_channels, which was
        # never assigned (only a local existed) — any channel check raised
        # AttributeError instead of the intended ValueError.
        self.input_channels = config.input_channels

        def make_group(n_in, n_out, pool):
            # One residual group: first block adapts n_in -> n_out, the rest
            # stay at n_out; optionally ends with 2x2 max-pooling.
            layers = [
                (f'block_{i + 1}',
                 EncoderBlock(n_in if i == 0 else n_out, n_out, n_layers, config))
                for i in blk_range
            ]
            if pool:
                layers.append(('pool', nn.MaxPool2d(kernel_size=2)))
            return nn.Sequential(OrderedDict(layers))

        self.blocks = nn.Sequential(OrderedDict([
            ('input', Conv2d(self.input_channels, n_hid, 7, config)),
            ('group_1', make_group(n_hid, n_hid, pool=True)),
            ('group_2', make_group(n_hid, 2 * n_hid, pool=True)),
            ('group_3', make_group(2 * n_hid, 4 * n_hid, pool=True)),
            ('group_4', make_group(4 * n_hid, 8 * n_hid, pool=False)),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                # Logit head stays in float32 for numerical stability.
                ('conv', Conv2d(
                    8 * n_hid, config.vocab_size,
                    1, config, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to vocab logits.

        Raises:
            ValueError: if ``x`` is not 4-d, has the wrong channel count, or
                is not float32.
        """
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')
        return self.blocks(x)