File size: 18,020 Bytes
c92f27c
c1ed878
c92f27c
 
 
 
 
 
 
 
 
 
 
 
 
0c71efa
c92f27c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4fb159
c92f27c
a4fb159
c92f27c
 
a4fb159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c92f27c
 
 
a4fb159
c92f27c
a4fb159
c92f27c
a4fb159
c92f27c
 
 
a4fb159
c92f27c
a4fb159
 
 
 
 
 
 
 
c92f27c
 
c1ed878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4fb159
 
 
 
 
 
 
 
c92f27c
 
a4fb159
c92f27c
a4fb159
 
 
c92f27c
a4fb159
 
 
 
 
 
 
c92f27c
a4fb159
 
c92f27c
 
a4fb159
 
c92f27c
a4fb159
 
 
 
 
c92f27c
a4fb159
c92f27c
 
 
 
a4fb159
 
 
 
 
 
 
 
c92f27c
 
 
 
a4fb159
 
c92f27c
 
 
 
 
 
 
a4fb159
c92f27c
 
 
a4fb159
c92f27c
 
 
 
a4fb159
c92f27c
 
 
a4fb159
 
 
 
 
 
c92f27c
a4fb159
 
 
 
 
 
 
c92f27c
 
 
 
 
a4fb159
c92f27c
 
 
 
 
 
 
 
 
a4fb159
c92f27c
 
 
a4fb159
 
c92f27c
a4fb159
 
c92f27c
 
 
 
a4fb159
c92f27c
 
 
a4fb159
c92f27c
 
a4fb159
c92f27c
 
a4fb159
c92f27c
a4fb159
 
 
 
 
 
c92f27c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4fb159
c92f27c
 
 
 
 
a4fb159
 
c92f27c
 
 
 
 
 
 
 
 
a4fb159
c92f27c
 
 
a4fb159
 
c92f27c
a4fb159
 
 
 
c92f27c
 
 
 
a4fb159
 
c92f27c
 
 
a4fb159
c92f27c
 
a4fb159
c92f27c
 
a4fb159
c92f27c
a4fb159
 
 
 
 
c92f27c
a4fb159
c92f27c
 
 
 
a4fb159
c92f27c
 
 
 
a4fb159
c92f27c
 
 
 
 
 
 
a4fb159
 
c92f27c
 
 
a4fb159
 
c92f27c
 
 
 
a4fb159
c92f27c
 
 
 
a4fb159
c92f27c
 
a4fb159
c92f27c
 
a4fb159
 
 
 
 
 
 
 
c92f27c
a4fb159
 
 
 
c92f27c
 
 
a4fb159
c92f27c
a4fb159
 
 
c92f27c
a4fb159
 
c92f27c
 
 
a4fb159
 
 
c92f27c
a4fb159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c92f27c
a4fb159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c92f27c
 
a4fb159
 
 
c92f27c
a4fb159
 
 
c92f27c
a4fb159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c92f27c
a4fb159
 
 
 
 
c92f27c
 
a4fb159
 
 
 
 
c92f27c
a4fb159
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
from typing import Callable, Optional, Tuple
from copy import deepcopy

import numpy as np

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from transformers import AlbertConfig
from transformers.models.albert.modeling_flax_albert import FlaxAlbertOnlyMLMHead, FlaxAlbertEmbeddings, FlaxAlbertPreTrainedModel
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging

from transformers.modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)

class CustomFlaxAlbertSelfAttention(nn.Module):
    """ALBERT self-attention with optional swapping of intermediate representations.

    Besides the usual query/key/value projection, attention, output projection,
    dropout and residual layer norm, ``__call__`` can exchange slices of the
    layer input or of the query/key/value tensors between two entries along the
    first axis before attention is computed (see ``interv_dict``).
    """
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            # f-strings so the offending values actually appear in the message.
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
                f"                   : {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic=True,
        output_attentions: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Run self-attention, applying any swaps registered for ``layer_id``.

        Args:
            hidden_states: Input array; its first two axes are kept and the
                projected last axis is split into ``(num_attention_heads, head_dim)``.
                Presumably ``(batch, seq_len, hidden_size)`` -- confirm with caller.
            attention_mask: Optional mask; entries ``> 0`` get bias ``0`` and all
                others get the most negative finite value of ``self.dtype``.
            deterministic: Disables both dropouts when true.
            output_attentions: Also return the attention weights when true.
            layer_id: Key looked up in ``interv_dict``.
            interv_type: Accepted for interface parity; not read in this method.
            interv_dict: Maps a layer id to a mapping whose keys are any of
                ``'lay'``, ``'qry'``, ``'key'``, ``'val'`` and whose values are
                iterables of ``(head_id, pos, swap_ids)``. For each triple the
                slices ``[swap_ids[0], pos, head_id]`` and
                ``[swap_ids[1], pos, head_id]`` of that representation are
                exchanged. The default dict is only read, never mutated.

        Returns:
            ``(output,)`` or ``(output, attn_weights)`` when ``output_attentions``.
        """
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        reps = {
            'lay': hidden_states,
            'qry': query_states,
            'key': key_states,
            'val': value_states,
        }
        if layer_id in interv_dict:
            interv = interv_dict[layer_id]
            for rep_name in ['lay', 'qry', 'key', 'val']:
                if rep_name in interv:
                    # Every swap reads from the untouched source so that several
                    # swaps on the same representation do not feed into each other.
                    source = jnp.asarray(reps[rep_name])
                    swapped = source
                    for head_id, pos, swap_ids in interv[rep_name]:
                        first, second = swap_ids[0], swap_ids[1]
                        # JAX arrays are immutable: item assignment raises, so
                        # use the functional `.at[...].set(...)` update instead.
                        swapped = swapped.at[first, pos, head_id].set(source[second, pos, head_id])
                        swapped = swapped.at[second, pos, head_id].set(source[first, pos, head_id])
                    reps[rep_name] = swapped

        # Query/key/value were projected from the un-swapped input above, so a
        # 'lay' swap only changes the residual term added before the layer norm.
        hidden_states = reps['lay']
        query_states = reps['qry']
        key_states = reps['key']
        value_states = reps['val']

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        # Merge the head and per-head axes back into a single feature axis.
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs

class CustomFlaxAlbertLayer(nn.Module):
    """One ALBERT layer: intervention-aware self-attention followed by a feed-forward block."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.attention = CustomFlaxAlbertSelfAttention(cfg, dtype=self.dtype)
        # Feed-forward expansion to the intermediate width.
        self.ffn = nn.Dense(
            cfg.intermediate_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[cfg.hidden_act]
        # Projection back down to the hidden width.
        self.ffn_output = nn.Dense(
            cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        layer_id: int = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Apply attention and the feed-forward block.

        The intervention arguments (``layer_id``, ``interv_type``,
        ``interv_dict``) are passed through to the attention sub-module
        unchanged. Returns ``(hidden_states,)``, with the attention weights
        appended when ``output_attentions`` is true.
        """
        attn_result = self.attention(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            layer_id=layer_id,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        attended = attn_result[0]

        # Feed-forward: expand -> activation -> project -> dropout.
        projected = self.ffn_output(self.activation(self.ffn(attended)))
        projected = self.dropout(projected, deterministic=deterministic)

        # Residual connection around the feed-forward block, then layer norm.
        normed = self.full_layer_layer_norm(projected + attended)

        if output_attentions:
            return (normed, attn_result[1])
        return (normed,)

class CustomFlaxAlbertLayerCollection(nn.Module):
    """Sequence of ``config.inner_group_num`` ALBERT layers applied one after another."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            CustomFlaxAlbertLayer(self.config, name=str(idx), dtype=self.dtype)
            for idx in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        layer_id: int = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Feed ``hidden_states`` through every inner layer.

        The same ``layer_id`` / ``interv_type`` / ``interv_dict`` are handed to
        each inner layer. Returns a tuple starting with the final hidden state,
        followed by the tuple of per-layer hidden states (if requested) and the
        tuple of per-layer attention weights (if requested), in that order.
        """
        per_layer_states = []
        per_layer_attn = []

        for albert_layer in self.layers:
            result = albert_layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                layer_id=layer_id,
                interv_type=interv_type,
                interv_dict=interv_dict,
            )
            hidden_states = result[0]

            if output_attentions:
                per_layer_attn.append(result[1])
            if output_hidden_states:
                per_layer_states.append(hidden_states)

        # last-layer hidden state, (layer hidden states), (layer attentions)
        outputs = (hidden_states,)
        if output_hidden_states:
            outputs += (tuple(per_layer_states),)
        if output_attentions:
            outputs += (tuple(per_layer_attn),)
        return outputs

class CustomFlaxAlbertLayerCollections(nn.Module):
    """Thin wrapper that owns a single ``CustomFlaxAlbertLayerCollection`` as ``albert_layers``."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    layer_index: Optional[str] = None

    def setup(self):
        self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        layer_id: int = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Delegate to the wrapped layer collection and return its output unchanged."""
        forwarded = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            layer_id=layer_id,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        return self.albert_layers(hidden_states, attention_mask, **forwarded)

class CustomFlaxAlbertLayerGroups(nn.Module):
    """Runs ``config.num_hidden_layers`` passes over ``config.num_hidden_groups`` shared layer groups."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            CustomFlaxAlbertLayerCollections(
                self.config, name=str(group), layer_index=str(group), dtype=self.dtype
            )
            for group in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Apply the layer groups in sequence.

        The loop index ``i`` is passed down as ``layer_id`` so that
        ``interv_dict`` can address individual passes even though several
        passes may share one group. Returns a ``FlaxBaseModelOutput`` when
        ``return_dict`` is true, otherwise a tuple of the non-``None`` entries
        among (last hidden state, all hidden states, all attentions).
        """
        all_hidden_states = (hidden_states,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            layers_per_group = self.config.num_hidden_layers / self.config.num_hidden_groups
            group_idx = int(i / layers_per_group)
            group_module = self.layers[group_idx]

            layer_group_output = group_module(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                layer_id=i,
                interv_type=interv_type,
                interv_dict=interv_dict,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                # The last element is the tuple of attention weights of this group.
                all_attentions += layer_group_output[-1]
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
            )
        return tuple(v for v in (hidden_states, all_hidden_states, all_attentions) if v is not None)

class CustomFlaxAlbertEncoder(nn.Module):
    """Projects the embeddings to ``config.hidden_size`` and runs the layer groups."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Encode ``hidden_states``.

        Args:
            hidden_states: Embedding output; passed through a dense projection
                to ``config.hidden_size`` before the layer groups.
            attention_mask: Forwarded to the layer groups.
            deterministic: Forwarded to the layer groups.
            output_attentions: Forwarded to the layer groups.
            output_hidden_states: Forwarded to the layer groups.
            return_dict: Forwarded to the layer groups, which return a
                ``FlaxBaseModelOutput`` when true and a plain tuple otherwise.
            interv_type: Forwarded to the layer groups.
            interv_dict: Forwarded to the layer groups.

        Returns:
            Whatever ``CustomFlaxAlbertLayerGroups`` returns for these flags.
        """
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Previously dropped here, so `return_dict=False` was silently ignored.
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )

class CustomFlaxAlbertModule(nn.Module):
    """Embeddings + intervention-aware encoder, with an optional tanh pooler on the first position."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
        if not self.add_pooling_layer:
            self.pooler = None
            self.pooler_activation = None
        else:
            self.pooler = nn.Dense(
                self.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
                name="pooler",
            )
            self.pooler_activation = nn.tanh

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[np.ndarray] = None,
        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Embed the inputs, encode them and optionally pool the first position.

        Missing ``token_type_ids`` default to zeros shaped like ``input_ids``;
        missing ``position_ids`` default to ``0..n-1`` along the last axis,
        broadcast to ``input_ids.shape``. Returns a
        ``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is true,
        otherwise a tuple in which the pooled output is omitted if there is no
        pooling layer.
        """
        # Fill in the optional id tensors when the caller did not supply them.
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)
        if position_ids is None:
            last_axis_len = jnp.atleast_2d(input_ids).shape[-1]
            position_ids = jnp.broadcast_to(jnp.arange(last_axis_len), input_ids.shape)

        embedded = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

        outputs = self.encoder(
            embedded,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        hidden_states = outputs[0]

        pooled = None
        if self.add_pooling_layer:
            pooled = self.pooler_activation(self.pooler(hidden_states[:, 0]))

        if return_dict:
            return FlaxBaseModelOutputWithPooling(
                last_hidden_state=hidden_states,
                pooler_output=pooled,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # if pooled is None, don't return it
        leading = (hidden_states,) if pooled is None else (hidden_states, pooled)
        return leading + outputs[1:]

class CustomFlaxAlbertForMaskedLMModule(nn.Module):
    """Intervention-aware ALBERT backbone (no pooler) topped with the masked-LM prediction head."""
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        """Compute masked-LM logits.

        All arguments are forwarded to the backbone. When
        ``config.tie_word_embeddings`` is set, the backbone's word-embedding
        matrix is handed to the prediction head as ``shared_embedding``.
        Returns a ``FlaxMaskedLMOutput`` when ``return_dict`` is true,
        otherwise ``(logits,)`` followed by the remaining backbone outputs.
        """
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        sequence_output = outputs[0]

        shared_embedding = (
            self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            if self.config.tie_word_embeddings
            else None
        )

        # Compute the prediction scores
        logits = self.predictions(sequence_output, shared_embedding=shared_embedding)

        if return_dict:
            return FlaxMaskedLMOutput(
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (logits,) + outputs[1:]

class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    """Pretrained-model wrapper whose Flax module is ``CustomFlaxAlbertForMaskedLMModule``.

    Only ``module_class`` is overridden here; everything else comes from
    ``FlaxAlbertPreTrainedModel``.
    NOTE(review): the inherited ``__call__`` is not visible in this file, so
    whether it accepts/forwards ``interv_type`` and ``interv_dict`` to the
    module needs to be confirmed.
    """
    # Module class associated with this wrapper.
    module_class = CustomFlaxAlbertForMaskedLMModule