# coding=utf-8
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and the DalleBart team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DalleBart model. """
import math
from functools import partial
from typing import Optional
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze
from flax.linen import make_causal_mask
from flax.traverse_util import flatten_dict
from jax.random import PRNGKey
from transformers.modeling_flax_outputs import (
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
)
from transformers.modeling_flax_utils import ACT2FN
from transformers.utils import logging
from transformers.models.bart.modeling_flax_bart import (
FlaxBartAttention,
FlaxBartEncoderLayer,
FlaxBartDecoderLayer,
FlaxBartEncoderLayerCollection,
FlaxBartDecoderLayerCollection,
FlaxBartEncoder,
FlaxBartDecoder,
FlaxBartModule,
FlaxBartForConditionalGenerationModule,
FlaxBartPreTrainedModel,
FlaxBartForConditionalGeneration,
)
logger = logging.get_logger(__name__)
class FlaxBartAttention(FlaxBartAttention):
"""
Edits:
- causal mask is used only in the decoder and covers image_length + 1 positions (image tokens plus BOS)
"""
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
if self.causal:
# used only in decoder
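# make_causal_mask produces a boolean lower-triangular mask of shape
# (1, 1, image_length + 1, image_length + 1): one position per image token
# plus one for the BOS token, so each decoder position can only attend to
# itself and earlier positions.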
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.image_length + 1), dtype="bool"), dtype="bool"
)
class FlaxBartEncoderLayer(FlaxBartEncoderLayer):
"""
Edits:
- no bias
- use custom FlaxBartAttention
"""
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
bias=False,
dtype=self.dtype,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim,
dtype=self.dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection):
"""
Edits:
- use custom FlaxBartEncoderLayer
- allow Gradient Checkpointing (nn.remat)
"""
def setup(self):
layer_module = (
nn.remat(FlaxBartEncoderLayer)
if self.config.gradient_checkpointing
else FlaxBartEncoderLayer
)
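# nn.remat recomputes each layer's activations during the backward pass
# instead of storing them, trading extra compute for lower peak memory when
# config.gradient_checkpointing is enabled.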
self.layers = [
layer_module(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop
class FlaxBartDecoderLayer(FlaxBartDecoderLayer):
"""
Edits:
- no bias
- uses custom FlaxBartAttention
"""
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
causal=True,
bias=False,
dtype=self.dtype,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.encoder_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
bias=False,
dtype=self.dtype,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.fc2 = nn.Dense(
self.embed_dim,
dtype=self.dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection):
"""
Edits:
- use custom FlaxBartDecoderLayer
- allow Gradient Checkpointing (nn.remat)
"""
def setup(self):
layer_module = (
nn.remat(FlaxBartDecoderLayer)
if self.config.gradient_checkpointing
else FlaxBartDecoderLayer
)
self.layers = [
layer_module(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop
class FlaxBartEncoder(FlaxBartEncoder):
"""
Edits:
- offset set to 0 (no padding token)
- use max_text_length instead of max_position_embeddings
- use custom FlaxBartEncoderLayerCollection
- embed_tokens cannot be None (issue at compile time)
"""
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
# Bart normally offsets embedding ids by 2 when a padding_idx is specified and
# adjusts num_embeddings accordingly; there is no padding token here, so that hack is disabled.
self.offset = 0
self.embed_positions = nn.Embed(
self.config.max_text_length + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
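# Learned position embeddings; with offset = 0 the table needs exactly
# max_text_length entries.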
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
class FlaxBartDecoder(FlaxBartDecoder):
"""
Edits:
- offset set to 0 (no padding token)
- use image_length + 1 (for BOS) instead of max_position_embeddings
- use custom FlaxBartDecoderLayerCollection
- embed_tokens cannot be None (issue at compile time)
"""
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.embed_scale = (
math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
)
# Bart normally offsets embedding ids by 2 when a padding_idx is specified and
# adjusts num_embeddings accordingly; there is no padding token here, so that hack is disabled.
self.offset = 0
self.embed_positions = nn.Embed(
self.config.image_length + 1 + self.offset, # image length + 1 for BOS
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
class FlaxBartModule(FlaxBartModule):
"""
Edits:
- use custom FlaxBartEncoder & FlaxBartDecoder
- use separate embeddings for Encoder & Decoder
"""
def setup(self):
encoder_embed_tokens = nn.Embed(
self.config.encoder_vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
decoder_embed_tokens = nn.Embed(
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std),
)
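# Text tokens and image tokens live in separate vocabularies, so the encoder
# and decoder get independent embedding tables instead of a shared one.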
self.encoder = FlaxBartEncoder(
self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens
)
self.decoder = FlaxBartDecoder(
self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens
)
class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
"""
Edits:
- added num_params property
"""
@property
def num_params(self):
num_params = jax.tree_map(
lambda param: param.size, flatten_dict(unfreeze(self.params))
).values()
return sum(list(num_params))
class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
"""
Edits:
- no bias
- lm_head set to image_vocab_size + 1 (for BOS)
- uses custom FlaxBartModule
"""
def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std),
)
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0]
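# Note: the tied-embeddings branch below expects a `shared` embedding table,
# which the custom FlaxBartModule above does not create (encoder and decoder
# embeddings are separate), so configs are expected to set
# tie_word_embeddings=False.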
if self.config.tie_word_embeddings:
shared_embedding = self.model.variables["params"]["shared"]["embedding"]
lm_logits = self.lm_head.apply(
{"params": {"kernel": shared_embedding.T}}, hidden_states
)
else:
lm_logits = self.lm_head(hidden_states)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
"""
Edits:
- renamed from FlaxBartForConditionalGeneration
- uses custom FlaxBartPreTrainedModel
- uses custom FlaxBartForConditionalGenerationModule
- no bias in decode method
"""
module_class = FlaxBartForConditionalGenerationModule
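# module_class points the generic pretrained-model machinery at the custom
# conditional-generation module defined above.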
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.return_dict
)
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError(
"Make sure to provide `decoder_position_ids` when passing `past_key_values`."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
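# Without a cache, position ids default to 0..sequence_length-1 for every
# batch element.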
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# If past_key_values are passed, the cache has already been initialized and the
# attention layers will use it. The cache collection must be marked as mutable
# so that it can be updated in place by the FlaxBartAttention module.
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
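# _decoder_forward runs only the decoder (plus lm_head) of the underlying
# module via `method=`, which is what autoregressive generation needs once
# the encoder outputs have been computed.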
def _decoder_forward(
module,
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
):
decoder_module = module._get_decoder_module()
outputs = decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = module.model.variables["params"]["shared"][
"embedding"
]
lm_logits = module.lm_head.apply(
{"params": {"kernel": shared_embedding.T}}, hidden_states
)
else:
lm_logits = module.lm_head(hidden_states)
return lm_logits, outputs
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
if past_key_values is None:
lm_logits, decoder_outputs = outputs
else:
(lm_logits, decoder_outputs), past = outputs
if return_dict:
outputs = FlaxCausalLMOutputWithCrossAttentions(
logits=lm_logits,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
)
else:
outputs = (lm_logits,) + decoder_outputs[1:]
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
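# Minimal usage sketch (hypothetical checkpoint id; all names below are only
# illustrative, and shapes depend on the actual config):
#
#   model = DalleBart.from_pretrained("user/checkpoint")  # hypothetical id
#   encoder_outputs = model.encode(input_ids, attention_mask)
#   cache = model.init_cache(batch_size, model.config.image_length + 1, encoder_outputs)
#   step = model.decode(
#       decoder_input_ids,
#       encoder_outputs,
#       decoder_attention_mask=decoder_attention_mask,
#       decoder_position_ids=decoder_position_ids,
#       past_key_values=cache,
#   )
#   next_token_logits = step.logits[:, -1]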