taka-yamakoshi committed on
Commit
a4fb159
1 Parent(s): 0c71efa

fix order of classes
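The reordering matters because `CustomFlaxAlbertForMaskedLM` assigns `module_class = CustomFlaxAlbertForMaskedLMModule` directly in its class body, and a Python class body is executed when the module is imported, so the referenced module class has to be defined before the wrapper class is created. A minimal sketch of the constraint, with hypothetical names rather than the classes in this file:

# Minimal sketch (hypothetical names): why definition order matters here.

class InnerModule:
    """Must come first: it is referenced in Wrapper's class body."""
    pass


class Wrapper:
    # Evaluated at class-creation time. If InnerModule were defined below
    # Wrapper instead of above it, importing this file would raise
    # "NameError: name 'InnerModule' is not defined".
    module_class = InnerModule

    def build(self):
        # Names used only inside a method body are looked up when the method
        # runs, so they do not constrain the definition order.
        return InnerModule()

Strictly, only the `module_class` assignment needs its target to precede it, since references inside each `setup()` are resolved at call time; the commit nevertheless reorders the whole file bottom-up (self-attention, layer, layer collection, layer collections, layer groups, encoder, module, masked-LM module, masked-LM model), so every class appears after the classes it depends on.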

Files changed (1)
  1. custom_modeling_albert_flax.py +246 -246
custom_modeling_albert_flax.py CHANGED
@@ -32,146 +32,118 @@ from transformers.modeling_flax_utils import (
     overwrite_call_docstring,
 )

-class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
-    module_class = CustomFlaxAlbertForMaskedLMModule
-
-class CustomFlaxAlbertForMaskedLMModule(nn.Module):
     config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32

     def setup(self):
-        self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
-        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

     def __call__(
         self,
-        input_ids,
         attention_mask,
-        token_type_ids,
-        position_ids,
-        deterministic: bool = True,
         output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
-        # Model
-        outputs = self.albert(
-            input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
-        )
-
-        hidden_states = outputs[0]
-        if self.config.tie_word_embeddings:
-            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
-        else:
-            shared_embedding = None
-
-        # Compute the prediction scores
-        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
-
-        if not return_dict:
-            return (logits,) + outputs[1:]

-        return FlaxMaskedLMOutput(
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
         )

-class CustomFlaxAlbertModule(nn.Module):
-    config: AlbertConfig
-    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    add_pooling_layer: bool = True
-
-    def setup(self):
-        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
-        self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
-        if self.add_pooling_layer:
-            self.pooler = nn.Dense(
-                self.config.hidden_size,
-                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-                dtype=self.dtype,
-                name="pooler",
             )
-            self.pooler_activation = nn.tanh
         else:
-            self.pooler = None
-            self.pooler_activation = None
-
-    def __call__(
-        self,
-        input_ids,
-        attention_mask,
-        token_type_ids: Optional[np.ndarray] = None,
-        position_ids: Optional[np.ndarray] = None,
-        deterministic: bool = True,
-        output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
-        interv_type: str = "swap",
-        interv_dict: dict = {},
-    ):
-        # make sure `token_type_ids` is correctly initialized when not passed
-        if token_type_ids is None:
-            token_type_ids = jnp.zeros_like(input_ids)
-
-        # make sure `position_ids` is correctly initialized when not passed
-        if position_ids is None:
-            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

-        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

-        outputs = self.encoder(
-            hidden_states,
-            attention_mask,
             deterministic=deterministic,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            interv_type=interv_type,
-            interv_dict=interv_dict,
         )
-        hidden_states = outputs[0]
-        if self.add_pooling_layer:
-            pooled = self.pooler(hidden_states[:, 0])
-            pooled = self.pooler_activation(pooled)
-        else:
-            pooled = None

-        if not return_dict:
-            # if pooled is None, don't return it
-            if pooled is None:
-                return (hidden_states,) + outputs[1:]
-            return (hidden_states, pooled) + outputs[1:]

-        return FlaxBaseModelOutputWithPooling(
-            last_hidden_state=hidden_states,
-            pooler_output=pooled,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )

-class CustomFlaxAlbertEncoder(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        self.embedding_hidden_mapping_in = nn.Dense(
             self.config.hidden_size,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             dtype=self.dtype,
         )
-        self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)

     def __call__(
         self,
@@ -179,30 +151,39 @@ class CustomFlaxAlbertEncoder(nn.Module):
         attention_mask,
         deterministic: bool = True,
         output_attentions: bool = False,
-        output_hidden_states: bool = False,
-        return_dict: bool = True,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
-        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
-        return self.albert_layer_groups(
             hidden_states,
             attention_mask,
             deterministic=deterministic,
             output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
             interv_type=interv_type,
             interv_dict=interv_dict,
         )

-class CustomFlaxAlbertLayerGroups(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.layers = [
-            CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
-            for i in range(self.config.num_hidden_groups)
         ]

     def __call__(
@@ -212,39 +193,37 @@ class CustomFlaxAlbertLayerGroups(nn.Module):
         deterministic: bool = True,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-        return_dict: bool = True,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
-        all_attentions = () if output_attentions else None
-        all_hidden_states = (hidden_states,) if output_hidden_states else None

-        for i in range(self.config.num_hidden_layers):
-            # Index of the hidden group
-            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
-            layer_group_output = self.layers[group_idx](
                 hidden_states,
                 attention_mask,
                 deterministic=deterministic,
                 output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                layer_id=i,
                 interv_type=interv_type,
                 interv_dict=interv_dict,
             )
-            hidden_states = layer_group_output[0]

             if output_attentions:
-                all_attentions = all_attentions + layer_group_output[-1]

             if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)

-        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return FlaxBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
-        )

 class CustomFlaxAlbertLayerCollections(nn.Module):
     config: AlbertConfig
@@ -277,13 +256,14 @@ class CustomFlaxAlbertLayerCollections(nn.Module):
         )
         return outputs

-class CustomFlaxAlbertLayerCollection(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.layers = [
-            CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
         ]

     def __call__(
@@ -293,57 +273,51 @@ class CustomFlaxAlbertLayerCollection(nn.Module):
         deterministic: bool = True,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
-        layer_id: int = None,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
-        layer_hidden_states = ()
-        layer_attentions = ()

-        for layer_index, albert_layer in enumerate(self.layers):
-            layer_output = albert_layer(
                 hidden_states,
                 attention_mask,
                 deterministic=deterministic,
                 output_attentions=output_attentions,
-                layer_id=layer_id,
                 interv_type=interv_type,
                 interv_dict=interv_dict,
             )
-            hidden_states = layer_output[0]

             if output_attentions:
-                layer_attentions = layer_attentions + (layer_output[1],)

             if output_hidden_states:
-                layer_hidden_states = layer_hidden_states + (hidden_states,)

-        outputs = (hidden_states,)
-        if output_hidden_states:
-            outputs = outputs + (layer_hidden_states,)
-        if output_attentions:
-            outputs = outputs + (layer_attentions,)
-        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)

-class CustomFlaxAlbertLayer(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
-        self.ffn = nn.Dense(
-            self.config.intermediate_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-        self.activation = ACT2FN[self.config.hidden_act]
-        self.ffn_output = nn.Dense(
             self.config.hidden_size,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             dtype=self.dtype,
         )
-        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(
         self,
@@ -351,121 +325,147 @@ class CustomFlaxAlbertLayer(nn.Module):
         attention_mask,
         deterministic: bool = True,
         output_attentions: bool = False,
-        layer_id: int = None,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
-        attention_outputs = self.attention(
             hidden_states,
             attention_mask,
             deterministic=deterministic,
             output_attentions=output_attentions,
-            layer_id=layer_id,
             interv_type=interv_type,
             interv_dict=interv_dict,
         )
-        attention_output = attention_outputs[0]
-        ffn_output = self.ffn(attention_output)
-        ffn_output = self.activation(ffn_output)
-        ffn_output = self.ffn_output(ffn_output)
-        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
-        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (attention_outputs[1],)
-        return outputs

-class CustomFlaxAlbertSelfAttention(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        if self.config.hidden_size % self.config.num_attention_heads != 0:
-            raise ValueError(
-                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
-                " : {self.config.num_attention_heads}"
             )
-
-        self.query = nn.Dense(
-            self.config.hidden_size,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-        )
-        self.key = nn.Dense(
-            self.config.hidden_size,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-        )
-        self.value = nn.Dense(
-            self.config.hidden_size,
-            dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-        )
-        self.dense = nn.Dense(
-            self.config.hidden_size,
-            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
-            dtype=self.dtype,
-        )
-        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
-        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(
         self,
-        hidden_states,
         attention_mask,
-        deterministic=True,
         output_attentions: bool = False,
-        layer_id: int = None,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
-        head_dim = self.config.hidden_size // self.config.num_attention_heads

-        query_states = self.query(hidden_states).reshape(
-            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
-        )
-        value_states = self.value(hidden_states).reshape(
-            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
         )
-        key_states = self.key(hidden_states).reshape(
-            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
         )

-        # Convert the boolean attention mask to an attention bias.
-        if attention_mask is not None:
-            # attention mask in the form of attention bias
-            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
-            attention_bias = lax.select(
-                attention_mask > 0,
-                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
-            )
-        else:
-            attention_bias = None

-        dropout_rng = None
-        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
-            dropout_rng = self.make_rng("dropout")

-        attn_weights = dot_product_attention_weights(
-            query_states,
-            key_states,
-            bias=attention_bias,
-            dropout_rng=dropout_rng,
-            dropout_rate=self.config.attention_probs_dropout_prob,
-            broadcast_dropout=True,
             deterministic=deterministic,
-            dtype=self.dtype,
-            precision=None,
         )

-        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
-        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

-        projected_attn_output = self.dense(attn_output)
-        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
-        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
-        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
-        return outputs
     overwrite_call_docstring,
 )

+class CustomFlaxAlbertSelfAttention(nn.Module):
     config: AlbertConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
+        if self.config.hidden_size % self.config.num_attention_heads != 0:
+            raise ValueError(
+                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
+                " : {self.config.num_attention_heads}"
+            )
+
+        self.query = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+        self.key = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+        self.value = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+        )
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            dtype=self.dtype,
+        )
+        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(
         self,
+        hidden_states,
         attention_mask,
+        deterministic=True,
         output_attentions: bool = False,
+        layer_id: int = None,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
+        head_dim = self.config.hidden_size // self.config.num_attention_heads

+        query_states = self.query(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+        value_states = self.value(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+        key_states = self.key(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
         )

+        # Convert the boolean attention mask to an attention bias.
+        if attention_mask is not None:
+            # attention mask in the form of attention bias
+            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
+            attention_bias = None

+        dropout_rng = None
+        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+            dropout_rng = self.make_rng("dropout")

+        attn_weights = dot_product_attention_weights(
+            query_states,
+            key_states,
+            bias=attention_bias,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.config.attention_probs_dropout_prob,
+            broadcast_dropout=True,
             deterministic=deterministic,
+            dtype=self.dtype,
+            precision=None,
         )

+        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

+        projected_attn_output = self.dense(attn_output)
+        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
+        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
+        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
+        return outputs

+class CustomFlaxAlbertLayer(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
+        self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
+        self.ffn = nn.Dense(
+            self.config.intermediate_size,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            dtype=self.dtype,
+        )
+        self.activation = ACT2FN[self.config.hidden_act]
+        self.ffn_output = nn.Dense(
             self.config.hidden_size,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             dtype=self.dtype,
         )
+        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

     def __call__(
         self,
         attention_mask,
         deterministic: bool = True,
         output_attentions: bool = False,
+        layer_id: int = None,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
+        attention_outputs = self.attention(
             hidden_states,
             attention_mask,
             deterministic=deterministic,
             output_attentions=output_attentions,
+            layer_id=layer_id,
             interv_type=interv_type,
             interv_dict=interv_dict,
         )
+        attention_output = attention_outputs[0]
+        ffn_output = self.ffn(attention_output)
+        ffn_output = self.activation(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
+        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)

+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attention_outputs[1],)
+        return outputs
+
+class CustomFlaxAlbertLayerCollection(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.layers = [
+            CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
         ]

     def __call__(
         deterministic: bool = True,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
+        layer_id: int = None,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
+        layer_hidden_states = ()
+        layer_attentions = ()

+        for layer_index, albert_layer in enumerate(self.layers):
+            layer_output = albert_layer(
                 hidden_states,
                 attention_mask,
                 deterministic=deterministic,
                 output_attentions=output_attentions,
+                layer_id=layer_id,
                 interv_type=interv_type,
                 interv_dict=interv_dict,
             )
+            hidden_states = layer_output[0]

             if output_attentions:
+                layer_attentions = layer_attentions + (layer_output[1],)

             if output_hidden_states:
+                layer_hidden_states = layer_hidden_states + (hidden_states,)

+        outputs = (hidden_states,)
+        if output_hidden_states:
+            outputs = outputs + (layer_hidden_states,)
+        if output_attentions:
+            outputs = outputs + (layer_attentions,)
+        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)

 class CustomFlaxAlbertLayerCollections(nn.Module):
     config: AlbertConfig
         )
         return outputs

+class CustomFlaxAlbertLayerGroups(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.layers = [
+            CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
+            for i in range(self.config.num_hidden_groups)
         ]

     def __call__(
         deterministic: bool = True,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
+        return_dict: bool = True,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
+        all_attentions = () if output_attentions else None
+        all_hidden_states = (hidden_states,) if output_hidden_states else None

+        for i in range(self.config.num_hidden_layers):
+            # Index of the hidden group
+            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+            layer_group_output = self.layers[group_idx](
                 hidden_states,
                 attention_mask,
                 deterministic=deterministic,
                 output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                layer_id=i,
                 interv_type=interv_type,
                 interv_dict=interv_dict,
             )
+            hidden_states = layer_group_output[0]

             if output_attentions:
+                all_attentions = all_attentions + layer_group_output[-1]

             if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)

+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+        return FlaxBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )

+class CustomFlaxAlbertEncoder(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
+        self.embedding_hidden_mapping_in = nn.Dense(
             self.config.hidden_size,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             dtype=self.dtype,
         )
+        self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)

     def __call__(
         self,
         attention_mask,
         deterministic: bool = True,
         output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
+        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
+        return self.albert_layer_groups(
             hidden_states,
             attention_mask,
             deterministic=deterministic,
             output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
             interv_type=interv_type,
             interv_dict=interv_dict,
         )
+class CustomFlaxAlbertModule(nn.Module):
     config: AlbertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    add_pooling_layer: bool = True

     def setup(self):
+        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
+        self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
+        if self.add_pooling_layer:
+            self.pooler = nn.Dense(
+                self.config.hidden_size,
+                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+                dtype=self.dtype,
+                name="pooler",
             )
+            self.pooler_activation = nn.tanh
+        else:
+            self.pooler = None
+            self.pooler_activation = None

     def __call__(
         self,
+        input_ids,
         attention_mask,
+        token_type_ids: Optional[np.ndarray] = None,
+        position_ids: Optional[np.ndarray] = None,
+        deterministic: bool = True,
         output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
         interv_type: str = "swap",
         interv_dict: dict = {},
     ):
+        # make sure `token_type_ids` is correctly initialized when not passed
+        if token_type_ids is None:
+            token_type_ids = jnp.zeros_like(input_ids)

+        # make sure `position_ids` is correctly initialized when not passed
+        if position_ids is None:
+            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)
+
+        outputs = self.encoder(
+            hidden_states,
+            attention_mask,
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            interv_type=interv_type,
+            interv_dict=interv_dict,
         )
+        hidden_states = outputs[0]
+        if self.add_pooling_layer:
+            pooled = self.pooler(hidden_states[:, 0])
+            pooled = self.pooler_activation(pooled)
+        else:
+            pooled = None
+
+        if not return_dict:
+            # if pooled is None, don't return it
+            if pooled is None:
+                return (hidden_states,) + outputs[1:]
+            return (hidden_states, pooled) + outputs[1:]
+
+        return FlaxBaseModelOutputWithPooling(
+            last_hidden_state=hidden_states,
+            pooler_output=pooled,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )

+class CustomFlaxAlbertForMaskedLMModule(nn.Module):
+    config: AlbertConfig
+    dtype: jnp.dtype = jnp.float32

+    def setup(self):
+        self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
+        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

+    def __call__(
+        self,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        position_ids,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+        interv_type: str = "swap",
+        interv_dict: dict = {},
+    ):
+        # Model
+        outputs = self.albert(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
             deterministic=deterministic,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            interv_type=interv_type,
+            interv_dict=interv_dict,
         )

+        hidden_states = outputs[0]
+        if self.config.tie_word_embeddings:
+            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
+        else:
+            shared_embedding = None

+        # Compute the prediction scores
+        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)
+
+        if not return_dict:
+            return (logits,) + outputs[1:]
+
+        return FlaxMaskedLMOutput(
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
+    module_class = CustomFlaxAlbertForMaskedLMModule
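
With the bottom-up order in place, the wrapper at the end of the file can be used like any other Flax masked-LM model in transformers. A rough usage sketch, assuming the file is importable as a module and that a stock ALBERT checkpoint name such as "albert-base-v2" is acceptable (neither assumption is part of this commit):

import numpy as np
from transformers import AlbertTokenizerFast

from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM

# from_pretrained is inherited from FlaxAlbertPreTrainedModel; it instantiates
# module_class, which is why that attribute must point at an already-defined class.
tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
model = CustomFlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")

inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
logits = model(**inputs).logits  # shape: (batch, seq_len, vocab_size)

# Decode the prediction at the [MASK] position.
mask_pos = int(np.argmax(inputs["input_ids"][0] == tokenizer.mask_token_id))
print(tokenizer.decode([int(logits[0, mask_pos].argmax())]))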