from typing import Callable, Optional, Tuple
from copy import deepcopy

import numpy as np

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from transformers import AlbertConfig
from transformers.models.albert.modeling_flax_albert import (
    FlaxAlbertOnlyMLMHead,
    FlaxAlbertEmbeddings,
    FlaxAlbertPreTrainedModel,
)
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from transformers.modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)


# Names of the representations that can be intervened on inside self-attention:
# the layer input ('lay') and the query / key / value projections.
_INTERVENTION_TARGETS = ("lay", "qry", "key", "val")


def _swap_batch_entries(state, swaps):
    """Return ``state`` with pairs of entries along its first axis exchanged.

    Args:
        state: array to intervene on.
        swaps: iterable of ``(head_id, pos, swap_ids)`` triples. For each triple
            the slice ``[swap_ids[0], pos, head_id]`` is exchanged with the
            slice ``[swap_ids[1], pos, head_id]``.

    Returns:
        A new array; the input is never modified.

    All replacement values are read from the un-swapped input, so several swaps
    touching the same location do not compose (the last write wins).
    """
    source = jnp.asarray(state)
    swapped = source
    for head_id, pos, swap_ids in swaps:
        first, second = swap_ids[0], swap_ids[1]
        # JAX arrays are immutable: item assignment raises TypeError, so the
        # update has to go through the functional `.at[...].set(...)` API.
        swapped = swapped.at[first, pos, head_id].set(source[second, pos, head_id])
        swapped = swapped.at[second, pos, head_id].set(source[first, pos, head_id])
    return swapped


class CustomFlaxAlbertSelfAttention(nn.Module):
    """ALBERT self-attention block with optional swap interventions.

    Besides the regular attention computation, ``__call__`` can exchange parts
    of the layer input or of the query/key/value projections between two
    entries of the first (presumably batch) axis, as described by
    ``interv_dict``.
    """

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the projection, layer-norm and dropout sub-modules.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            # f-strings so that the actual values appear in the message.
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
                f" : {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic=True,
        output_attentions: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Run self-attention, applying any intervention registered for ``layer_id``.

        Args:
            hidden_states: layer input; its first two axes are preserved when
                the projections are split into heads.
            attention_mask: mask whose positive entries mark positions that may
                be attended to, or ``None`` for no masking.
            deterministic: disables dropout when true.
            output_attentions: also return the attention weights when true.
            layer_id: key used to look up this layer's entry in ``interv_dict``.
            interv_type: accepted for interface compatibility; the code below
                only implements swapping and does not consult this value.
            interv_dict: mapping ``layer_id -> {rep_name: [(head_id, pos,
                swap_ids), ...]}`` with ``rep_name`` in ``'lay'``, ``'qry'``,
                ``'key'``, ``'val'``. ``None`` or an empty mapping means no
                intervention.

        Returns:
            ``(output,)`` or ``(output, attn_weights)`` when
            ``output_attentions`` is true.
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        split_shape = hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)

        query_states = self.query(hidden_states).reshape(split_shape)
        value_states = self.value(hidden_states).reshape(split_shape)
        key_states = self.key(hidden_states).reshape(split_shape)

        if interv_dict is not None and layer_id in interv_dict:
            interv = interv_dict[layer_id]
            reps = {
                "lay": hidden_states,
                "qry": query_states,
                "key": key_states,
                "val": value_states,
            }
            for rep_name in _INTERVENTION_TARGETS:
                if rep_name in interv:
                    reps[rep_name] = _swap_batch_entries(reps[rep_name], interv[rep_name])
            # The projections above were computed from the un-swapped input; a
            # 'lay' swap therefore only affects the residual connection below.
            hidden_states = reps["lay"]
            query_states = reps["qry"]
            key_states = reps["key"]
            value_states = reps["val"]

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # Merge the head axes back into a single feature axis.
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs


class CustomFlaxAlbertLayer(nn.Module):
    """Self-attention followed by a feed-forward block with a residual layer norm."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the attention, feed-forward, layer-norm and dropout sub-modules."""
        self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
        self.ffn = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
        self.ffn_output = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Apply attention and the feed-forward block.

        The intervention arguments are forwarded unchanged to the attention
        module.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is true.
        """
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            layer_id=layer_id,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        attention_output = attention_outputs[0]
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


class CustomFlaxAlbertLayerCollection(nn.Module):
    """The ``config.inner_group_num`` layers that make up one hidden group."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Instantiate one layer per inner-group slot, named by its index."""
        self.layers = [
            CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Run every inner layer in sequence.

        The same ``layer_id`` and intervention arguments are handed to each
        inner layer.

        Returns:
            A tuple of the last hidden state, optionally followed by the tuple
            of per-layer hidden states and the tuple of per-layer attentions.
        """
        layer_hidden_states = ()
        layer_attentions = ()

        for albert_layer in self.layers:
            layer_output = albert_layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                layer_id=layer_id,
                interv_type=interv_type,
                interv_dict=interv_dict,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class CustomFlaxAlbertLayerCollections(nn.Module):
    """Thin wrapper holding one layer collection under the ``albert_layers`` name."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    layer_index: Optional[str] = None

    def setup(self):
        """Create the wrapped layer collection."""
        self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Delegate to the wrapped layer collection and return its output tuple."""
        outputs = self.albert_layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            layer_id=layer_id,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        return outputs


class CustomFlaxAlbertLayerGroups(nn.Module):
    """Applies the hidden groups ``config.num_hidden_layers`` times in total."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Instantiate one layer-collection wrapper per hidden group."""
        self.layers = [
            CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Run all layer iterations, passing the iteration index as ``layer_id``.

        Returns:
            A ``FlaxBaseModelOutput``, or, when ``return_dict`` is false, a
            tuple of the non-``None`` values among the last hidden state, all
            hidden states and all attentions.
        """
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
            layer_group_output = self.layers[group_idx](
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                layer_id=i,
                interv_type=interv_type,
                interv_dict=interv_dict,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                # The last element is the tuple of attentions of the inner layers.
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class CustomFlaxAlbertEncoder(nn.Module):
    """Projects the embeddings to ``hidden_size`` and runs the layer groups."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the input projection and the layer groups."""
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Map the embeddings into the hidden size and return the layer-group output."""
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Forwarded so that `return_dict=False` is honoured by the encoder.
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )


class CustomFlaxAlbertModule(nn.Module):
    """Embeddings, intervention-aware encoder and optional pooler."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        """Create the embeddings, the encoder and, if requested, the pooler."""
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
        if self.add_pooling_layer:
            self.pooler = nn.Dense(
                self.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
                name="pooler",
            )
            self.pooler_activation = nn.tanh
        else:
            self.pooler = None
            self.pooler_activation = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[np.ndarray] = None,
        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Embed the inputs, encode them and optionally pool the first position.

        Returns:
            A ``FlaxBaseModelOutputWithPooling``, or a tuple when
            ``return_dict`` is false (the pooled output is omitted from the
            tuple when there is no pooling layer).
        """
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        hidden_states = outputs[0]
        if self.add_pooling_layer:
            pooled = self.pooler(hidden_states[:, 0])
            pooled = self.pooler_activation(pooled)
        else:
            pooled = None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class CustomFlaxAlbertForMaskedLMModule(nn.Module):
    """ALBERT backbone (without pooler) topped with the masked-LM head."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the backbone and the prediction head."""
        self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: Optional[dict] = None,
    ):
        """Compute masked-LM logits, forwarding the intervention arguments to the backbone.

        Returns:
            A ``FlaxMaskedLMOutput``, or ``(logits, ...)`` when ``return_dict``
            is false.
        """
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            # Reuse the input word-embedding matrix in the prediction head.
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    module_class = CustomFlaxAlbertForMaskedLMModule