import math
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_dalle import DallEConfig


class Conv2d(nn.Module):
    def __init__(self, n_in, n_out, kw, config, use_float16=True):
        super().__init__()
        assert n_in >= 1
        assert n_out >= 1
        assert kw >= 1 and kw % 2 == 1

        self.n_in = n_in
        self.n_out = n_out
        self.kw = kw
        self.config = config
        self.use_float16 = use_float16

        # Initialize the weights before enabling gradient tracking: an in-place
        # normal_() on a leaf tensor created with requires_grad=True raises a
        # RuntimeError. nn.Parameter also forces requires_grad=True by default,
        # so the config flag is applied explicitly when wrapping.
        w = torch.empty((n_out, n_in, kw, kw), dtype=torch.float32, device=config.device)
        w.normal_(std=1 / math.sqrt(n_in * kw ** 2))
        b = torch.zeros((n_out,), dtype=torch.float32, device=config.device)

        self.w = nn.Parameter(w, requires_grad=config.requires_grad)
        self.b = nn.Parameter(b, requires_grad=config.requires_grad)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run in half precision on GPU when requested; fall back to float32
        # elsewhere, since half-precision convolutions are not generally
        # supported on CPU.
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()
            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()
            w, b = self.w, self.b

        # Odd kernel width with (kw - 1) // 2 padding keeps the spatial size unchanged.
        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

    def extra_repr(self):
        return (
            f"n_in={self.n_in}, n_out={self.n_out}, kw={self.kw}, "
            f"use_float16={self.use_float16}, "
            f"device={self.config.device}, "
            f"requires_grad={self.config.requires_grad}"
        )


class EncoderBlock(nn.Module):
    def __init__(self, n_in, n_out, n_layers, config):
        super().__init__()
        assert n_in >= 1
        assert n_out >= 1 and n_out % 4 == 0
        assert n_layers >= 1

        self.n_in = n_in
        self.n_out = n_out
        self.n_hid = n_out // 4
        # Scale the residual branch down with network depth to keep the
        # residual sum well behaved in deep stacks.
        self.post_gain = 1 / (n_layers ** 2)

        # Identity shortcut; a 1x1 convolution matches channel counts when they differ.
        if self.n_in != self.n_out:
            self.id_path = Conv2d(self.n_in, self.n_out, 1, config)
        else:
            self.id_path = nn.Identity()

        # Bottleneck residual branch: a 3x3 convolution down to n_out // 4
        # channels, two more 3x3 convolutions at that width, then a 1x1
        # convolution back up to n_out.
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', Conv2d(self.n_in, self.n_hid, 3, config)),
            ('relu_2', nn.ReLU()),
            ('conv_2', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_3', nn.ReLU()),
            ('conv_3', Conv2d(self.n_hid, self.n_hid, 3, config)),
            ('relu_4', nn.ReLU()),
            ('conv_4', Conv2d(self.n_hid, self.n_out, 1, config)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)
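
# Illustrative sketch (not part of the model): a minimal shape check for
# EncoderBlock. The SimpleNamespace below is a hypothetical stand-in for a
# DallEConfig, supplying only the two attributes Conv2d actually reads
# (`device` and `requires_grad`). Defined but never called, so importing this
# module stays side-effect free.
def _encoder_block_shape_check():
    from types import SimpleNamespace
    cfg = SimpleNamespace(device='cpu', requires_grad=False)
    blk = EncoderBlock(n_in=64, n_out=128, n_layers=8, config=cfg)
    out = blk(torch.zeros(1, 64, 32, 32))
    # 'Same'-padded convolutions preserve the 32x32 spatial size; the 1x1
    # shortcut maps 64 channels to n_out=128 so the residual sum lines up.
    assert out.shape == (1, 128, 32, 32)
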
class DallEPreTrainedModel(PreTrainedModel):
    config_class = DallEConfig
    base_model_prefix = "dalle"


class DallEEncoder(DallEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        blk_range = range(config.n_blk_per_group)
        n_layers = config.group_count * config.n_blk_per_group
        # Stored on the module so the channel check in forward() can see it.
        self.input_channels = config.input_channels
        n_hid = config.n_hid

        # Four groups of residual blocks. Each group after the first doubles
        # the channel count in its first block, and the first three groups end
        # with a 2x max-pool, so e.g. a 256x256 input yields a 32x32 grid.
        self.blocks = nn.Sequential(OrderedDict([
            ('input', Conv2d(self.input_channels, n_hid, 7, config)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', EncoderBlock(n_hid, n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', EncoderBlock(
                    n_hid if i == 0 else 2 * n_hid, 2 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', EncoderBlock(
                    2 * n_hid if i == 0 else 4 * n_hid, 4 * n_hid, n_layers, config))
                  for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', EncoderBlock(
                    4 * n_hid if i == 0 else 8 * n_hid, 8 * n_hid, n_layers, config))
                  for i in blk_range],
            ]))),
            # 1x1 projection to per-position logits over the visual vocabulary,
            # kept in float32 so the final logits avoid half precision.
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', Conv2d(8 * n_hid, config.vocab_size, 1, config, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(
                f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')
        return self.blocks(x)
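
# Usage sketch (illustrative, with assumptions): encode a batch of images into
# per-position logits over the visual vocabulary, then take an argmax to get
# discrete image-token indices. Assumes DallEConfig() supplies usable defaults
# for input_channels, n_hid, group_count, n_blk_per_group, vocab_size, device,
# and requires_grad; the argmax step is standard dVAE usage, not defined here.
def _encode_example():
    config = DallEConfig()
    encoder = DallEEncoder(config).eval()
    x = torch.randn(2, config.input_channels, 256, 256)  # float32, as forward() requires
    with torch.no_grad():
        logits = encoder(x)  # (2, vocab_size, 32, 32): three 2x pools take 256 down to 32
    return logits.argmax(dim=1)  # (2, 32, 32), one token index per spatial position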