"""
perceiver.py
Generic interface to various configurations of the Perceiver Resampler, which simply takes in a series of (potentially
time-indexed) contextual embeddings and "resamples" (compresses) them down to a pre-specified number of latents!
Note that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here
to prime the Perceiver Resampler with, say, a single layer's worth of language embeddings (the target domain), and use
that to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
References:
    - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
    - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
"""

import torch
import torch.nn as nn
from einops import rearrange, repeat


class PerceiverResampler(nn.Module):
    def __init__(self, config, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int) -> None:
        """
        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say, from a ResNet, ViT, or
        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed set of `n_latents` latent
        embeddings, then returns a Tensor of shape [bsz, n_latents, embed_dim].
        :param config: Model config; must provide `qk_layer_norms_perceiver` and may provide `vision_embed_dim`.
        :param embed_dim: Dimensionality of the embeddings fed to the Perceiver Resampler (also the dimensionality of
                          the latent embeddings *returned* by it). Could be, e.g., the ViT embed_dim, ResNet pool dim,
                          and so on.
        :param depth: Depth of the Perceiver Resampler (Transformer w/ cross-attention). Should be shallow (< 3).
        :param n_heads: Number of heads in each Transformer block (for multi-headed attention over the latents).
        :param head_dim: Dimensionality of each head projection in the Transformer block.
        :param n_latents: Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
        """
        super().__init__()
        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
        self.qk_layer_norms = config.qk_layer_norms_perceiver

        # Create Latents for Perceiver
        self.latents = nn.Parameter(torch.randn(self.n_latents, self.embed_dim), requires_grad=True)

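        # MLP hidden size -- the standard Transformer 4x expansion of the (vision) embedding dimension.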
        self.intermediate_dim = (
            self.embed_dim * 4 if not hasattr(config, "vision_embed_dim") else config.vision_embed_dim * 4
        )
        # Create Transformer Blocks
        self.blocks = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms),
                        MLP(self.intermediate_dim, config),
                    ]
                )
                for _ in range(depth)
            ]
        )
        self.layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
        latents = repeat(self.latents, "seq embed -> bsz seq embed", bsz=context.shape[0])
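        # latents: [bsz, n_latents, embed_dim] -- learned queries broadcast across the batch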

        # Feed through Perceiver Attention blocks...
        for attn, ff in self.blocks:
            latents = attn(context, latents) + latents
            latents = ff(latents) + latents

        return self.layer_norm(latents)


class PerceiverAttention(nn.Module):
    def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__()
        self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
        self.qk_layer_norms = qk_layer_norms
        # Normalization & Scaling
        self.context_layer_norm = nn.LayerNorm(self.embed_dim)
        self.latents_layer_norm = nn.LayerNorm(self.embed_dim)
        if self.qk_layer_norms:
            self.q_layer_norm = nn.LayerNorm(self.head_dim)
            self.k_layer_norm = nn.LayerNorm(self.head_dim)

        self.qk_scale = self.head_dim**-0.5

        # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).
        self.q_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.n_heads * self.head_dim, bias=False)

        self.output_proj = nn.Linear(self.n_heads * self.head_dim, embed_dim, bias=False)

    def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        """
        Runs Perceiver cross-attention: queries come from `latents`, while keys/values come from `context` with
        `latents` appended along the `seq` dimension.
        :param context: Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
        :param latents: Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
        :return: Tensor of shape [bsz, n_latents, embed_dim] -- the latents after cross-attending over (context + latents).
        """
        context = self.context_layer_norm(context)
        latents = self.latents_layer_norm(latents)

        # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
        #   Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
        q = self.q_proj(latents)
        k = self.k_proj(torch.cat([context, latents], dim=-2))
        v = self.v_proj(torch.cat([context, latents], dim=-2))

        # Multiheaded attention w/ stable softmax (subtract per-row max -- `amax` -- before the softmax call)
        #   =>> `attn` has shape [bsz, n_heads, n_latents, len(context) + n_latents]
        q, k, v = [rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) for x in (q, k, v)]
        if self.qk_layer_norms:
            q = self.q_layer_norm(q)
            k = self.k_layer_norm(k)

        scores = torch.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
        stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())
        attn = stabilized_scores.softmax(dim=-1)

        # Attend & project back to output...
        resampled = torch.einsum("... i j, ... j d -> ... i d", attn, v)
        return self.output_proj(
            rearrange(resampled, "bsz heads seq embed -> bsz seq (heads embed)", heads=self.n_heads)
        )


class MLP(nn.Module):
    def __init__(self, intermediate_size: int, config) -> None:
        """Simple MLP block: LayerNorm --> Linear (to `intermediate_size`) --> ReLU --> Linear (back to embed dim)"""
        super().__init__()
        self.embed_dim = config.vision_embed_dim
        self.ln = nn.LayerNorm(self.embed_dim)
        self.fc = nn.Linear(self.embed_dim, intermediate_size, bias=False)
        self.act = nn.ReLU()
        self.c_proj = nn.Linear(intermediate_size, self.embed_dim, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.ln(hidden_states)
        hidden_states = self.fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)

        return hidden_states
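

if __name__ == "__main__":
    # Minimal smoke-test sketch: exercises the resampler end-to-end with random inputs. It assumes a config
    # exposing `qk_layer_norms_perceiver` and `vision_embed_dim`; `SimpleNamespace` stands in for the real
    # model config here purely for illustration.
    from types import SimpleNamespace

    config = SimpleNamespace(qk_layer_norms_perceiver=True, vision_embed_dim=768)
    resampler = PerceiverResampler(config, embed_dim=768, depth=2, n_heads=8, head_dim=96, n_latents=64)

    # A batch of 4 "images", each encoded into 257 patch embeddings (e.g., a ViT with a CLS token).
    context = torch.randn(4, 257, 768)
    resampled = resampler(context)
    print(resampled.shape)  # torch.Size([4, 64, 768])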