| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Auto-regressive text decoder in GIT paper. |
| |
| GIT: A Generative Image-to-text Transformer for Vision and Language. Wang et al. |
| |
| arXiv: https://arxiv.org/abs/2205.14100 |
| |
| reference torch implementation: |
| https://github.com/microsoft/GenerativeImage2Text/blob/main/ |
| generativeimage2text/layers/decoder.py |
| |
| """ |
|
|
| from flax import linen as nn |
| import jax |
| import jax.numpy as jnp |
|
|
| from scenic.model_lib.layers import nn_layers |
|
|
| NEG_INF = float('-inf') |
|
|
|
|
class BertSelfAttention(nn.Module):
  """Multi-head self attention for the BERT-style GIT decoder.

  Attributes:
    num_heads: Number of attention heads.
    hidden_size: Total projection width; assumed divisible by `num_heads`.
    attention_dropout: Dropout rate applied to the attention weights.
  """

  num_heads: int = 12
  hidden_size: int = 768
  attention_dropout: float = 0.1

  @nn.compact
  def __call__(
      self, input_tensor, attention_mask, train=False):
    """Runs multi-head self attention over `input_tensor`.

    Args:
      input_tensor: (batch_size, seq_len, hidden_size) activations.
      attention_mask: additive mask broadcastable against
        (batch_size, num_heads, seq_len, seq_len); blocked positions carry
        large negative values.
      train: whether attention dropout is active.

    Returns:
      (batch_size, seq_len, hidden_size) attended features.
    """
    head_dim = self.hidden_size // self.num_heads

    def project(name):
      # Dense projection followed by a split into per-head channels:
      # (B, L, hidden) -> (B, heads, L, head_dim).
      projected = nn.Dense(
          self.hidden_size,
          kernel_init=nn.initializers.normal(stddev=0.02),
          name=name)(input_tensor)
      batch, length, _ = projected.shape
      return projected.reshape(
          batch, length, self.num_heads, head_dim).transpose(0, 2, 1, 3)

    queries = project('query')
    keys = project('key')
    values = project('value')

    # Scaled dot-product logits with the additive mask folded in.
    logits = jnp.einsum(
        'bhqd,bhkd->bhqk', queries * (head_dim ** -0.5), keys)
    logits = logits + attention_mask
    weights = jax.nn.softmax(logits, axis=-1)
    weights = nn.Dropout(self.attention_dropout)(
        weights, deterministic=not train)

    context = jnp.einsum('bhqk,bhkd->bhqd', weights, values)
    batch, _, length, _ = context.shape
    # Merge the heads back: (B, heads, L, head_dim) -> (B, L, hidden).
    return context.transpose(0, 2, 1, 3).reshape(
        batch, length, self.hidden_size)
|
|
|
|
class BertSelfOutput(nn.Module):
  """Post-attention projection with residual connection and LayerNorm.

  Attributes:
    hidden_size: Output width of the dense projection.
    hidden_dropout: Dropout rate on the projected activations.
    stochastic_depth: Drop rate for stochastic depth on the residual branch.
  """

  hidden_size: int = 768
  hidden_dropout: float = 0.1
  stochastic_depth: float = 0.0

  @nn.compact
  def __call__(self, hidden_states, input_tensor, train=False):
    """Projects `hidden_states`, adds the residual `input_tensor`, normalizes."""
    projected = nn.Dense(
        self.hidden_size,
        kernel_init=nn.initializers.normal(stddev=0.02),
        name='dense')(hidden_states)
    projected = nn.Dropout(self.hidden_dropout)(
        projected, deterministic=not train)
    projected = nn_layers.StochasticDepth(self.stochastic_depth)(
        projected, deterministic=not train)
    # NOTE(review): epsilon 1e-5 here differs from the 1e-12 used by
    # BertOutput's LayerNorm elsewhere in this file — presumably matches the
    # reference checkpoint; confirm against the torch implementation.
    return nn.LayerNorm(epsilon=1e-5, name='LayerNorm')(
        projected + input_tensor)
|
|
|
|
class BertAttention(nn.Module):
  """Full attention sub-block: self attention followed by its output stage.

  Attributes:
    hidden_size: Model width.
    num_heads: Number of attention heads.
    dropout: Dropout rate used in the output stage.
    attention_dropout: Dropout rate on the attention weights.
    stochastic_depth: Drop rate for stochastic depth in the output stage.
  """
  hidden_size: int = 768
  num_heads: int = 12
  dropout: float = 0.1
  attention_dropout: float = 0.1
  stochastic_depth: float = 0.0

  @nn.compact
  def __call__(
      self, input_tensor, attention_mask, train=False):
    """Applies self attention then the residual output projection."""
    attended = BertSelfAttention(
        num_heads=self.num_heads,
        hidden_size=self.hidden_size,
        attention_dropout=self.attention_dropout,
        name='self')(input_tensor, attention_mask, train=train)
    return BertSelfOutput(
        hidden_size=self.hidden_size,
        hidden_dropout=self.dropout,
        stochastic_depth=self.stochastic_depth,
        name='output')(attended, input_tensor, train=train)
|
|
|
|
class BertIntermediate(nn.Module):
  """Feed-forward expansion with exact (erf-based) GELU.

  Attributes:
    intermediate_size: Width of the expanded representation.
  """

  intermediate_size: int = 768 * 4

  @nn.compact
  def __call__(
      self, hidden_states, train=False):
    """Expands `hidden_states` to `intermediate_size` and applies GELU.

    `train` is unused; it is kept for interface symmetry with the
    surrounding layers.
    """
    expanded = nn.Dense(
        self.intermediate_size,
        kernel_init=nn.initializers.normal(stddev=0.02),
        name='dense')(hidden_states)
    # approximate=False selects the exact GELU, not the tanh approximation.
    return nn.gelu(expanded, approximate=False)
|
|
|
|
class BertOutput(nn.Module):
  """Feed-forward contraction with residual connection and LayerNorm.

  Attributes:
    hidden_size: Output width of the dense projection.
    hidden_dropout: Dropout rate on the projected activations.
    stochastic_depth: Drop rate for stochastic depth on the residual branch.
  """

  hidden_size: int = 768
  hidden_dropout: float = 0.1
  stochastic_depth: float = 0.0

  @nn.compact
  def __call__(
      self, hidden_states, input_tensor, train=False):
    """Projects back to `hidden_size`, adds the residual, normalizes."""
    contracted = nn.Dense(
        self.hidden_size,
        kernel_init=nn.initializers.normal(stddev=0.02),
        name='dense')(hidden_states)
    contracted = nn.Dropout(self.hidden_dropout)(
        contracted, deterministic=not train)
    contracted = nn_layers.StochasticDepth(self.stochastic_depth)(
        contracted, deterministic=not train)
    return nn.LayerNorm(epsilon=1e-12, name='LayerNorm')(
        contracted + input_tensor)
|
|
|
|
class BertLayer(nn.Module):
  """One GIT transformer layer: self attention then the feed-forward block.

  Attributes:
    hidden_size: Model width.
    num_heads: Number of attention heads.
    dropout: Dropout rate for the residual projections.
    attention_dropout: Dropout rate on the attention weights.
    stochastic_depth: Drop rate for stochastic depth in this layer.
  """
  hidden_size: int = 768
  num_heads: int = 12
  dropout: float = 0.1
  attention_dropout: float = 0.1
  stochastic_depth: float = 0.0

  @nn.compact
  def __call__(
      self, hidden_states, attention_mask, train=False):
    """Forward layer.

    Args:
      hidden_states: (batch_size, tot_len, hidden_size).
      attention_mask: (1, 1, tot_len, tot_len) additive mask.
      train: bool; enables dropout and stochastic depth.
    Returns:
      hidden_states: (batch_size, tot_len, hidden_size).
    """
    attended = BertAttention(
        num_heads=self.num_heads,
        hidden_size=self.hidden_size,
        dropout=self.dropout,
        attention_dropout=self.attention_dropout,
        stochastic_depth=self.stochastic_depth,
        name='attention')(hidden_states, attention_mask, train=train)
    # Feed-forward: expand by 4x, GELU, then contract with a residual.
    expanded = BertIntermediate(
        intermediate_size=self.hidden_size * 4,
        name='intermediate')(attended, train=train)
    return BertOutput(
        hidden_size=self.hidden_size,
        hidden_dropout=self.dropout,
        stochastic_depth=self.stochastic_depth,
        name='output')(expanded, attended, train=train)
|
|
|
|
class BertEncoder(nn.Module):
  """Stack of BertLayer blocks with linearly increasing stochastic depth.

  Attributes:
    num_hidden_layers: Number of stacked transformer layers.
    hidden_size: Model width.
    num_heads: Number of attention heads.
    stochastic_depth: Maximum stochastic-depth rate (reached at the last
      layer; the first layer uses rate 0).
    dropout: Dropout rate for the residual projections.
    attention_dropout: Dropout rate on the attention weights.
  """
  num_hidden_layers: int = 6
  hidden_size: int = 768
  num_heads: int = 12
  stochastic_depth: float = 0.0
  dropout: float = 0.1
  attention_dropout: float = 0.1

  @nn.compact
  def __call__(
      self, hidden_states, attention_mask, train=False):
    """forward encoder.

    Args:
      hidden_states: (batch_size, tot_len, hidden_size).
      attention_mask: (1, 1, tot_len, tot_len).
      train: bool.
    Returns:
      hidden_states: (batch_size, tot_len, hidden_size).
    """
    assert 0.0 <= self.stochastic_depth < 1.0
    assert 0.0 <= self.dropout < 1.0
    assert 0.0 <= self.attention_dropout < 1.0

    denom = max(self.num_hidden_layers - 1, 1)
    for layer_idx in range(self.num_hidden_layers):
      # Rate ramps linearly from 0 (first layer) to `stochastic_depth`
      # (last layer).
      layer_rate = (layer_idx / denom) * self.stochastic_depth
      hidden_states = BertLayer(
          hidden_size=self.hidden_size,
          num_heads=self.num_heads,
          stochastic_depth=layer_rate,
          dropout=self.dropout,
          attention_dropout=self.attention_dropout,
          name=f'layer.{layer_idx}',
      )(hidden_states, attention_mask, train=train)
    return hidden_states
|
|
|
|
class BertEncoderAsDecoder(nn.Module):
  """GIT decoder: a BERT encoder over [visual features; caption tokens].

  Visual tokens attend bidirectionally among themselves; caption tokens
  attend to all visual tokens and, via `tgt_mask`, to caption tokens.
  Visual tokens never attend to caption tokens.

  Attributes:
    num_hidden_layers: Number of transformer layers.
    hidden_size: Model width.
    num_heads: Number of attention heads.
  """
  num_hidden_layers: int = 6
  hidden_size: int = 768
  num_heads: int = 12

  @nn.compact
  def __call__(
      self, tgt, memory, tgt_mask=None,
      memory_key_padding_mask=None, train=False, return_visual_feature=False,):
    """forward transformer.

    Args:
      tgt: (batch_size, cap_len, hidden_size) caption token embeddings.
      memory: (batch_size, feat_len, hidden_size) visual features.
      tgt_mask: (cap_len, cap_len) additive mask over caption tokens, or
        None to use the standard causal (future-blocking) mask.
      memory_key_padding_mask: (batch_size, feat_len). Padded is 1, valid is 0.
      train: bool.
      return_visual_feature: if True, the feat_len visual positions are kept
        in the output instead of being sliced off.
    Returns:
      result: (batch_size, cap_len, hidden_size), or
        (batch_size, feat_len + cap_len, hidden_size) when
        `return_visual_feature` is True.
    """
    cap_len = tgt.shape[1]
    feat_len = memory.shape[1]
    if tgt_mask is None:
      # BUGFIX: tgt_mask defaulted to None but was used unconditionally,
      # crashing in jnp.concatenate. Default to a causal mask: position i
      # may not attend to any j > i.
      future = jnp.triu(jnp.ones((cap_len, cap_len), jnp.float32), k=1)
      tgt_mask = jnp.where(future > 0, NEG_INF, 0)
    hidden_states = jnp.concatenate(
        [memory, tgt], axis=1
    )

    # Block-structured additive mask over the concatenated sequence:
    #   [ visual->visual: 0     visual->text: -inf    ]
    #   [ text->visual:   0     text->text:  tgt_mask ]
    top_left = jnp.zeros((feat_len, feat_len), dtype=jnp.float32)
    top_right = jnp.full((feat_len, cap_len), NEG_INF, dtype=jnp.float32)
    bottom_left = jnp.zeros((cap_len, feat_len), dtype=jnp.float32)
    left = jnp.concatenate([top_left, bottom_left], axis=0)
    right = jnp.concatenate([top_right, tgt_mask], axis=0)

    full_attention_mask = jnp.concatenate(
        [left, right],
        axis=1)[None]
    if memory_key_padding_mask is not None:
      # Expand to one mask per batch element, then add -inf on the columns
      # of padded visual positions so nothing attends to them.
      full_attention_mask = jnp.broadcast_to(
          full_attention_mask,
          (memory_key_padding_mask.shape[0],
           full_attention_mask.shape[1], full_attention_mask.shape[2]))
      padding_bias = jnp.where(
          memory_key_padding_mask,
          NEG_INF,
          jnp.zeros_like(memory_key_padding_mask, dtype=tgt.dtype))
      origin_left = full_attention_mask[:, :, :feat_len]
      full_attention_mask = jnp.concatenate(
          [origin_left + padding_bias[:, None, :],
           full_attention_mask[:, :, feat_len:]],
          axis=2)
    # Insert the head axis: (batch or 1, 1, tot_len, tot_len).
    full_attention_mask = full_attention_mask[
        :, None, :, :]

    result = BertEncoder(
        num_hidden_layers=self.num_hidden_layers,
        hidden_size=self.hidden_size,
        num_heads=self.num_heads,
        name='encoder')(
            hidden_states=hidden_states,
            attention_mask=full_attention_mask,
            train=train,
        )
    if not return_visual_feature:
      result = result[:, feat_len:]
    return result
|
|
|
|
def generate_future_mask(size, dtype=jnp.float32):
  """Generates an additive causal (future-blocking) attention mask.

  Args:
    size: int, sequence length.
    dtype: output dtype. New optional parameter; the default preserves the
      previous float32 behavior.

  Returns:
    (size, size) array whose strict upper triangle is -inf and whose lower
    triangle (diagonal included) is 0, so adding it to attention logits
    prevents position i from attending to any position j > i.
  """
  future = jnp.triu(jnp.ones((size, size), dtype), k=1)
  # float('-inf') is the same value as the module-level NEG_INF constant.
  return jnp.where(future > 0, float('-inf'), 0).astype(dtype)
|
|