import math from collections import OrderedDict from typing import List, Optional, Tuple, cast import attr import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from .attention import ( AttentionInfo, DenseAttentionMask, DenseCausalAttentionMask, make_full_layout, to_attention_info, ) from .utils import Affine, LayerNorm, zero_key_bias_grad # Constants used in the original CLIP implementation. image_channel_means = [122.77093945, 116.74601272, 104.09373519] image_channel_stds = [68.50053285, 66.63215831, 70.32316309] @attr.s(eq=False, repr=False) class TextEmbedding(nn.Module): n_vocab: int = attr.ib() n_context: int = attr.ib() n_state: int = attr.ib() device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() w_voc = torch.empty((self.n_vocab, self.n_state), dtype=torch.float32, device=self.device) w_pos = torch.empty((self.n_context, self.n_state), dtype=torch.float32, device=self.device) with torch.no_grad(): w_voc.normal_(std=0.02) w_pos.normal_(std=0.01) self.w_voc = nn.Parameter(w_voc) self.w_pos = nn.Parameter(w_pos) def forward(self, x: torch.Tensor) -> torch.Tensor: if len(x.shape) != 2: raise ValueError() return F.embedding(x, self.w_voc) + self.w_pos[None, :, :] @attr.s(eq=False, repr=False) class ImageEmbedding(nn.Module): image_size: int = attr.ib() patch_size: int = attr.ib() n_state: int = attr.ib() n_timestep: int = attr.ib(default=0) device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() if self.image_size % self.patch_size != 0: raise ValueError() n_patch = self.image_size // self.patch_size patch_proj = torch.empty( (self.n_state, 3) + 2 * (self.patch_size,), dtype=torch.float32, device=self.device ) w_pos = torch.empty( (1 + n_patch ** 2, self.n_state), dtype=torch.float32, device=self.device ) with torch.no_grad(): if self.n_timestep == 0: pred_state = torch.empty((self.n_state,), dtype=torch.float32, device=self.device) pred_state.normal_(std=1 / np.sqrt(self.n_state)) self.pred_state = nn.Parameter(pred_state) else: w_t = torch.empty( (self.n_timestep, self.n_state), dtype=torch.float32, device=self.device ) w_t.normal_(std=1 / np.sqrt(self.n_state)) self.w_t = nn.Parameter(w_t) patch_proj.normal_(std=np.sqrt(2 / (self.n_state * self.patch_size ** 2))) w_pos.normal_(std=1 / np.sqrt(self.n_state)) self.patch_proj = nn.Parameter(patch_proj) self.w_pos = nn.Parameter(w_pos) self.channel_means = torch.tensor( image_channel_means, dtype=torch.float32, device=self.device )[None, :, None, None] self.channel_stds = torch.tensor( image_channel_stds, dtype=torch.float32, device=self.device )[None, :, None, None] self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor: if len(x.shape) != 4: raise ValueError("input should be 4d") if x.shape[1] != 3: raise ValueError("input should have 3 channels") if not (x.shape[2] == self.image_size and x.shape[3] == self.image_size): raise ValueError(f"input is not {self.image_size} x {self.image_size}") if (self.n_timestep == 0 and t is not None) or (self.n_timestep != 0 and t is None): raise ValueError() if self.n_timestep != 0: assert t is not None if len(t.shape) != 1: raise ValueError() if t.shape[0] != x.shape[0]: raise ValueError() x = (x - self.channel_means) / self.channel_stds x = F.conv2d(x, self.patch_proj, stride=self.patch_size) x = x.reshape(x.shape[0], self.n_state, (self.image_size // self.patch_size) ** 2).permute( 0, 2, 1 ) sot = ( self.pred_state[None, None].expand(x.shape[0], -1, -1) if self.n_timestep == 0 else F.embedding(cast(torch.Tensor, t), self.w_t)[:, None] ) x = torch.cat((sot, x), dim=1) + self.w_pos[None] return self.ln(x) @attr.s(eq=False, repr=False) class AttentionResblock(nn.Module): n_state: int = attr.ib() n_resblocks: int = attr.ib() attn_fn: AttentionInfo = attr.ib() device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() self.n_head_state = self.n_state // self.attn_fn.n_heads self.qk_scale = 1 / np.sqrt(self.n_head_state) self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) self.f_q = Affine( self.n_state, self.n_state, std=1 / math.sqrt(self.n_state), use_bias=True, bias_filter_fn=zero_key_bias_grad, device=self.device, ) self.f_k = Affine( self.n_state, self.n_state, std=1 / math.sqrt(self.n_state), use_bias=False, bias_filter_fn=zero_key_bias_grad, device=self.device, ) self.f_v = Affine( self.n_state, self.n_state, std=1 / math.sqrt(self.n_state), use_bias=True, bias_filter_fn=zero_key_bias_grad, device=self.device, ) self.f_c = Affine( self.n_state, self.n_state, use_bias=True, std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2), device=self.device, ) # XXX def forward(self, m: torch.Tensor) -> torch.Tensor: n_context = m.shape[1] n_query_pad = self.attn_fn.ctx_blks_q * self.attn_fn.block_size - n_context n_key_pad = self.attn_fn.ctx_blks_k * self.attn_fn.block_size - n_context assert n_query_pad >= 0 assert n_key_pad >= 0 r = m r = self.ln(r) q, k, v = self.f_q(r), self.f_k(r), self.f_v(r) if n_query_pad != 0: q = F.pad(q, (0, 0, 0, n_query_pad)) if n_key_pad != 0: k = F.pad(k, (0, 0, 0, n_key_pad)) v = F.pad(v, (0, 0, 0, n_key_pad)) q = q.view([q.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) k = k.view([k.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) v = v.view([v.shape[0], -1, self.attn_fn.n_heads, self.n_head_state]).permute((0, 2, 1, 3)) w = torch.einsum( "bhcd,bhkd->bhck", q * math.sqrt(self.qk_scale), k * math.sqrt(self.qk_scale) ) if hasattr(self.attn_fn, "pytorch_attn_bias"): bias = self.attn_fn.pytorch_attn_bias assert len(bias.shape) in {2, 3} if len(bias.shape) == 2: w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None, None], dim=-1) elif len(bias.shape) == 3: w = torch.softmax(w + self.attn_fn.pytorch_attn_bias[None], dim=-1) else: w = torch.softmax(w, dim=-1) r = torch.einsum("bhck,bhkd->bhcd", w, v) r = r.permute((0, 2, 1, 3)).reshape((r.shape[0], -1, self.n_state)) if n_query_pad != 0: r = r[:, :-n_query_pad] assert r.shape[1] == n_context r = self.f_c(r) return m + r @attr.s(eq=False, repr=False) class FullyConnectedResblock(nn.Module): """ Not imported from other files because we retain Alec's original inits. """ n_state: int = attr.ib() n_resblocks: int = attr.ib() device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) self.f_1 = Affine( self.n_state, 4 * self.n_state, use_bias=True, std=np.sqrt(2 / (4 * self.n_state)), device=self.device, ) self.f_2 = Affine( 4 * self.n_state, self.n_state, use_bias=True, std=1 / np.sqrt(self.n_state * self.n_resblocks ** 2), device=self.device, ) # XXX def forward(self, m: torch.Tensor) -> torch.Tensor: r = m r = self.ln(r) r = self.f_2(F.gelu(self.f_1(r))) return m + r @attr.s(eq=False, repr=False) class TransformerBlock(nn.Module): n_state: int = attr.ib() n_resblocks: int = attr.ib() attn_fn: AttentionInfo = attr.ib() device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() self.f_attn = AttentionResblock( self.n_state, self.n_resblocks, self.attn_fn, self.device, ) self.f_mlp = FullyConnectedResblock(self.n_state, self.n_resblocks, self.device) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.f_mlp(self.f_attn(x)) @attr.s(eq=False, repr=False) class TextFeatureExtractor(nn.Module): n_state: int = attr.ib() n_embd: int = attr.ib() device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device) def forward( self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False ) -> torch.Tensor: if len(text.shape) != 3: raise ValueError("expected text to be 3d") if len(text_len.shape) != 1: raise ValueError("expected text length to be 1d") if text.shape[0] != text_len.shape[0]: raise ValueError("text and text_len have inconsistent batch dimensions") index = (text_len - 1)[:, None, None].expand(-1, 1, text.shape[2]) x = torch.gather(text, dim=1, index=index) assert list(x.shape) == [text.shape[0], 1, text.shape[2]] if return_probe_features: return x[:, 0] x = self.ln(x) return self.f(x[:, 0]) @attr.s(eq=False, repr=False) class ImageFeatureExtractor(nn.Module): n_state: int = attr.ib() n_embd: int = attr.ib() device: torch.device = attr.ib(default=torch.device("cuda")) def __attrs_post_init__(self) -> None: super().__init__() self.ln = LayerNorm(self.n_state, eps=1e-5, device=self.device) self.f = Affine(self.n_state, self.n_embd, use_bias=False, device=self.device) def forward(self, x: torch.Tensor, return_probe_features: bool = False) -> torch.Tensor: if return_probe_features: return x[:, 0] x = self.ln(x[:, :1]) return self.f(x[:, 0]) @attr.s(eq=False, repr=False) class TextEncoder(nn.Module): n_bpe_vocab: int = attr.ib() max_text_len: int = attr.ib() n_embd: int = attr.ib() n_head: int = attr.ib() n_xf_blocks: int = attr.ib() n_head_state: int = attr.ib(default=64) device: torch.device = attr.ib(default=torch.device("cuda")) block_size: int = attr.ib(init=False, default=32) def __attrs_post_init__(self) -> None: super().__init__() self.n_state = self.n_head * self.n_head_state n_rounded_context = self.block_size * int(math.ceil(self.max_text_len / self.block_size)) n_pad = n_rounded_context - self.max_text_len args = ( n_rounded_context, n_rounded_context, self.block_size, self.n_head, False, n_pad, n_pad, ) mask = DenseCausalAttentionMask(*args) attn_fn = to_attention_info(mask) m = 1 - make_full_layout(mask).astype(np.float32) m[m == 1] = -1e10 attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device) blocks: List[Tuple[str, nn.Module]] = [ ( "input", TextEmbedding( self.n_bpe_vocab, self.max_text_len, self.n_state, device=self.device ), ) ] for i in range(self.n_xf_blocks): blocks.append( ( f"block_{i}", TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device), ) ) blocks.append( ("output", TextFeatureExtractor(self.n_state, self.n_embd, device=self.device)) ) self.blocks = nn.ModuleDict(OrderedDict(blocks)) def forward( self, text: torch.Tensor, text_len: torch.Tensor, return_probe_features: bool = False, ) -> torch.Tensor: n_batch = text.shape[0] h = self.blocks["input"](text) for i in range(self.n_xf_blocks): h = self.blocks[f"block_{i}"](h) h = self.blocks["output"](h, text_len, return_probe_features=return_probe_features) assert list(h.shape) == [ n_batch, self.n_embd if not return_probe_features else self.n_state, ] return h @attr.s(eq=False, repr=False) class ImageEncoder(nn.Module): image_size: int = attr.ib() patch_size: int = attr.ib() n_embd: int = attr.ib() n_head: int = attr.ib() n_xf_blocks: int = attr.ib() n_head_state: int = attr.ib(default=64) n_timestep: int = attr.ib(default=0) device: torch.device = attr.ib(default=torch.device("cuda")) block_size: int = attr.ib(init=False, default=32) def __attrs_post_init__(self) -> None: super().__init__() self.n_state = self.n_head * self.n_head_state self.n_context = 1 + (self.image_size // self.patch_size) ** 2 n_rounded_context = self.block_size * int(math.ceil(self.n_context / self.block_size)) n_pad = n_rounded_context - self.n_context args = ( n_rounded_context, n_rounded_context, self.block_size, self.n_head, False, n_pad, n_pad, ) mask = DenseAttentionMask(*args) attn_fn = to_attention_info(mask) m = 1 - make_full_layout(mask).astype(np.float32) m[m == 1] = -1e10 attn_fn.pytorch_attn_bias = torch.from_numpy(m).to(self.device) blocks: List[Tuple[str, nn.Module]] = [ ( "input", ImageEmbedding( self.image_size, self.patch_size, self.n_state, n_timestep=self.n_timestep, device=self.device, ), ) ] for i in range(self.n_xf_blocks): blocks.append( ( f"block_{i}", TransformerBlock(self.n_state, 2 * self.n_xf_blocks, attn_fn, self.device), ) ) blocks.append(("output", ImageFeatureExtractor(self.n_state, self.n_embd, self.device))) self.blocks = nn.ModuleDict(OrderedDict(blocks)) def forward( self, image: torch.Tensor, timesteps: Optional[torch.Tensor] = None, return_probe_features: bool = False, ) -> torch.Tensor: n_batch = image.shape[0] h = self.blocks["input"](image, t=timesteps) for i in range(self.n_xf_blocks): h = self.blocks[f"block_{i}"](h) h = self.blocks["output"](h, return_probe_features=return_probe_features) assert list(h.shape) == [ n_batch, self.n_embd if not return_probe_features else self.n_state, ] return h