# dall-e-tokenizer/modeling_dalle.py
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from transformers import PreTrainedModel
from .configuration_dalle import DallEConfig
class Conv2d(nn.Module):
def __init__(self, n_in, n_out, kw, config, use_float16=True):
super().__init__()
assert n_in >= 1
assert n_out >= 1
assert kw >= 1 and kw % 2 == 1
self.n_in = n_in
self.n_out = n_out
self.kw = kw
self.config = config
self.use_float16 = use_float16
        # Fan-in-scaled init: std = 1 / sqrt(n_in * kw**2).
        w = torch.empty(
            (n_out, n_in, kw, kw),
            dtype=torch.float32,
            device=config.device,
            requires_grad=config.requires_grad,
        )
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))
b = torch.zeros(
(n_out,),
dtype=torch.float32,
device=config.device,
requires_grad=config.requires_grad,
)
self.w = nn.Parameter(w)
self.b = nn.Parameter(b)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # On CUDA, optionally compute in float16; everywhere else fall back
        # to float32 so CPU inference still works.
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b
        # kw is asserted odd, so (kw - 1) // 2 gives "same" padding.
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
def extra_repr(self):
inner_repr = f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, "
inner_repr += f"use_float16={self.use_float16}, "
inner_repr += f"device={self.config.device}, "
inner_repr += f"requires_grad={self.config.requires_grad}"
return inner_repr
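# A quick shape sketch for Conv2d above (hypothetical sizes; `config` is
# assumed to be a DallEConfig providing `device` and `requires_grad`). Any
# odd kernel width preserves spatial dims thanks to the (kw - 1) // 2 padding:
#
#     conv = Conv2d(n_in=3, n_out=64, kw=7, config=config)
#     y = conv(torch.rand(1, 3, 256, 256))  # -> (1, 64, 256, 256)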
class EncoderBlock(nn.Module):
def __init__(self, n_in, n_out, n_layers, config):
super().__init__()
assert n_in >= 1
assert n_out >= 1 and n_out % 4 == 0
assert n_layers >= 1
        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4  # bottleneck width
        # Scale the residual branch by 1 / n_layers**2 so deep stacks stay
        # well-behaved at initialization.
        self.post_gain = 1 / (n_layers ** 2)
        # Identity path needs a 1x1 conv only when the channel count changes.
        if self.n_in != self.n_out:
            self.id_path = Conv2d(self.n_in, self.n_out, 1, config)
        else:
            self.id_path = nn.Identity()
self.res_path = nn.Sequential(OrderedDict([
('relu_1', nn.ReLU()),
('conv_1', Conv2d(self.n_in, self.n_hid, 3, config)),
('relu_2', nn.ReLU()),
('conv_2', Conv2d(self.n_hid, self.n_hid, 3, config)),
('relu_3', nn.ReLU()),
('conv_3', Conv2d(self.n_hid, self.n_hid, 3, config)),
('relu_4', nn.ReLU()),
('conv_4', Conv2d(self.n_hid, self.n_out, 1, config)),
]))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.id_path(x) + self.post_gain * self.res_path(x)
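# A minimal sketch of the bottleneck residual above (hypothetical sizes).
# n_out must be divisible by 4 because the hidden width is n_out // 4, and
# post_gain = 1 / n_layers**2 keeps the residual branch small at init:
#
#     blk = EncoderBlock(n_in=64, n_out=128, n_layers=8, config=config)
#     y = blk(torch.rand(1, 64, 32, 32))  # -> (1, 128, 32, 32)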
class DallEPreTrainedModel(PreTrainedModel):
    """Base class hooking DALL-E modules into the transformers save/load API."""
    config_class = DallEConfig
    base_model_prefix = "dalle"
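# Hedged sketch: subclassing PreTrainedModel gives the standard transformers
# save/load API, e.g. (the repo id below is purely hypothetical):
#
#     encoder = DallEEncoder.from_pretrained("some-user/dalle-encoder")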
class DallEEncoder(DallEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group
        in_channels = config.input_channels
        n_hid = config.n_hid
        self.input_channels = in_channels  # validated against inputs in forward()
        # Four groups of EncoderBlocks: channel width doubles between groups,
        # and groups 1-3 end with a 2x2 max-pool that halves spatial dims.
self.blocks = nn.Sequential(OrderedDict([
('input', Conv2d(in_channels, n_hid, 7, config)),
('group_1', nn.Sequential(OrderedDict([
*[(f'block_{i + 1}',
EncoderBlock(n_hid, n_hid, n_layers, config))
for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_2', nn.Sequential(OrderedDict([
*[(f'block_{i + 1}',
EncoderBlock(
n_hid if i == 0 else 2 * n_hid,
2 * n_hid, n_layers, config))
for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_3', nn.Sequential(OrderedDict([
*[(f'block_{i + 1}',
EncoderBlock(
2 * n_hid if i == 0 else 4 * n_hid,
4 * n_hid, n_layers, config))
for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_4', nn.Sequential(OrderedDict([
*[(f'block_{i + 1}',
EncoderBlock(
4 * n_hid if i == 0 else 8 * n_hid,
8 * n_hid, n_layers, config))
for i in blk_range],
]))),
('output', nn.Sequential(OrderedDict([
('relu', nn.ReLU()),
('conv', Conv2d(
8 * n_hid, config.vocab_size,
1, config, use_float16=False)),
]))),
]))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.ndim != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(
                f'input has {x.shape[1]} channels but model built for '
                f'{self.input_channels}'
            )
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')
        return self.blocks(x)
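# End-to-end usage sketch (assumptions: a default DallEConfig, a 256x256
# input, and pixels already mapped into [0, 1] as in the original DALL-E
# preprocessing). The three MaxPool2d stages downsample 256 -> 32, so the
# logits come out at 32x32:
#
#     config = DallEConfig()
#     encoder = DallEEncoder(config).eval()
#     x = torch.rand(1, config.input_channels, 256, 256)
#     with torch.no_grad():
#         logits = encoder(x)            # (1, config.vocab_size, 32, 32)
#         tokens = logits.argmax(dim=1)  # discrete image-token ids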