prikmm committed
Commit 10c0b09
1 Parent(s): 498c5e0

Adds flax port for IndicTrans2

Files changed (3)
  1. config.json +2 -1
  2. flax_model.msgpack +3 -0
  3. modeling_flax_indictrans.py +1373 -0
config.json CHANGED
@@ -7,7 +7,8 @@
   ],
   "auto_map": {
   "AutoConfig": "configuration_indictrans.IndicTransConfig",
- "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
+ "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration",
+ "FlaxAutoModelForSeq2SeqLM": "modeling_flax_indictrans.FlaxIndicTransForConditionalGeneration"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 0,
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9d1470cfb17a2881611d9d2dda3e1ab3457914ce16a099d8dc5466fe202e85ad
+ size 847132908
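
The three lines above are a Git LFS pointer rather than the weights themselves: the actual ~847 MB flax_model.msgpack is stored in LFS and identified by its sha256 oid. A small sketch (file path is an assumption: wherever the LFS object was downloaded locally) of checking a download against that oid:

import hashlib

sha = hashlib.sha256()
# Assumed local path to the downloaded LFS object.
with open("flax_model.msgpack", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

# The pointer's oid is the sha256 of the file contents, so the two should match.
print(sha.hexdigest() == "9d1470cfb17a2881611d9d2dda3e1ab3457914ce16a099d8dc5466fe202e85ad")
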
modeling_flax_indictrans.py ADDED
@@ -0,0 +1,1373 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Flax IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union, Callable
20
+ from functools import partial
21
+
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
26
+ from flax.linen import combine_masks, make_causal_mask
27
+ from flax.linen.attention import dot_product_attention_weights
28
+ from flax.traverse_util import flatten_dict, unflatten_dict
29
+ from jax import lax
30
+ from jax.random import PRNGKey
31
+
32
+ from transformers.modeling_flax_outputs import (
33
+ FlaxBaseModelOutput,
34
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
35
+ FlaxCausalLMOutputWithCrossAttentions,
36
+ FlaxSeq2SeqLMOutput,
37
+ FlaxSeq2SeqModelOutput,
38
+ )
39
+ from transformers.modeling_flax_utils import (
40
+ ACT2FN,
41
+ FlaxPreTrainedModel,
42
+ append_call_sample_docstring,
43
+ append_replace_return_docstrings,
44
+ overwrite_call_docstring,
45
+ )
46
+ from configuration_indictrans import IndicTransConfig
47
+ from transformers.utils import logging
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CONFIG_FOR_DOC = "IndicTransConfig"
53
+
54
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
55
+
56
+ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
57
+ """
58
+ Shift input ids one token to the right.
59
+ """
60
+ shifted_input_ids = jnp.zeros_like(input_ids)
61
+ shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
62
+ shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
63
+
64
+ if pad_token_id is None:
65
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
66
+ # replace possible -100 values in labels by `pad_token_id`
67
+ shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
68
+
69
+ return shifted_input_ids
70
+
71
+
72
+ class FlaxIndicTransSinusoidalPositionalEmbedding(nn.Module):
73
+ """This module produces sinusoidal positional embeddings of any length."""
74
+ num_positions: int
75
+ embedding_dim: int
76
+ padding_idx: Optional[int] = None
77
+
78
+ # IndicTrans is set up so that, if padding_idx is specified, the embedding ids are offset by 2
79
+ # and num_embeddings is adjusted accordingly. Other models don't have this hack.
80
+ offset: int = 2
81
+
82
+ def setup(self) -> None:
83
+ self.weights = self._make_weights(self.num_positions + self.offset, self.embedding_dim, padding_idx=self.padding_idx)
84
+
85
+ def _make_weights(
86
+ self,
87
+ num_embeddings: int,
88
+ embedding_dim: int,
89
+ existing_weights: Optional[jnp.array] = None,
90
+ padding_idx: Optional[int] = None
91
+ ):
92
+ emb_weights = self._get_embedding(num_embeddings, embedding_dim, padding_idx)
93
+
94
+ if existing_weights is not None:
95
+ # Convert emb_weights to the same dtype as existing_weights
96
+ emb_weights = emb_weights.astype(existing_weights.dtype)
97
+
98
+ return emb_weights
99
+
100
+ def _get_embedding(
101
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
102
+ ):
103
+ """
104
+ Build sinusoidal embeddings.
105
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
106
+ "Attention Is All You Need".
107
+ """
108
+ half_dim = embedding_dim // 2
109
+ emb = math.log(10000) / (half_dim - 1)
110
+ emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
111
+ emb = jnp.arange(num_embeddings, dtype=jnp.float32).reshape(-1, 1) * emb.reshape(1, -1)
112
+ emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1).reshape(num_embeddings, -1)
113
+
114
+ if embedding_dim % 2 == 1:
115
+ # zero pad
116
+ emb = jnp.concatenate([emb, jnp.zeros((num_embeddings, 1), dtype=emb.dtype)], axis=1)
117
+
118
+ if padding_idx is not None:
119
+ emb = emb.at[padding_idx].set(0)
120
+
121
+ return emb
122
+
123
+ def __call__(
124
+ self,
125
+ input_ids: jnp.array = None,
126
+ inputs_embeds: jnp.array = None,
127
+ past_key_values_length: int = 0
128
+ ):
129
+ if input_ids is not None:
130
+ bsz, seq_len = input_ids.shape
131
+ position_ids = self._create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
132
+ else:
133
+ bsz, seq_len = inputs_embeds.shape[:-1]
134
+ position_ids = self._create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
135
+
136
+ # Expand embeddings if needed
137
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
138
+ if max_pos > self.weights.shape[0]:
139
+ self.weights = self._make_weights(max_pos + self.offset, self.embedding_dim, self.weights, self.padding_idx)
140
+
141
+ return self.weights[position_ids.ravel()].reshape(bsz, seq_len, -1)
142
+
143
+ def _create_position_ids_from_input_ids(
144
+ self, input_ids, padding_idx, past_key_values_length=0
145
+ ):
146
+ """
147
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
148
+ are ignored. This is a JAX conversion of the PyTorch function.
149
+ """
150
+ mask = (input_ids != padding_idx)
151
+ incremental_indices = (jnp.cumsum(mask, axis=1) + past_key_values_length) * mask
152
+ return incremental_indices + padding_idx
153
+
154
+ def _create_position_ids_from_inputs_embeds(
155
+ self, inputs_embeds, past_key_values_length
156
+ ):
157
+ """
158
+ Generate sequential position ids from input embeddings.
159
+ Args:
160
+ inputs_embeds: jnp.array (JAX array)
161
+ past_key_values_length: int
162
+ Returns:
163
+ jnp.array: Position IDs corresponding to the inputs.
164
+ """
165
+ input_shape = inputs_embeds.shape[:-1]
166
+ sequence_length = input_shape[1]
167
+
168
+ position_ids = jnp.arange(self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=jnp.int64)
169
+ return jnp.expand_dims(position_ids, axis=0).repeat(input_shape[0], axis=0) + past_key_values_length
170
+
171
+
172
+ class FlaxIndicTransAttention(nn.Module):
173
+ config: IndicTransConfig
174
+ embed_dim: int
175
+ num_heads: int
176
+ dropout: float = 0.0
177
+ causal: bool = False
178
+ bias: bool = True
179
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
180
+
181
+ def setup(self) -> None:
182
+
183
+ self.head_dim = self.embed_dim // self.num_heads
184
+ if self.head_dim * self.num_heads != self.embed_dim:
185
+ raise ValueError(
186
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
187
+ f" and `num_heads`: {self.num_heads})."
188
+ )
189
+
190
+ # Not required in Flax Module as `dot_product_attention_weights` handles scaling internally.
191
+ # For more details, check: https://flax.readthedocs.io/en/latest/_modules/flax/linen/attention.html#dot_product_attention_weights
192
+ # self.scaling
193
+
194
+ dense = partial(
195
+ nn.Dense,
196
+ self.embed_dim,
197
+ use_bias=self.bias,
198
+ dtype=self.dtype,
199
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
200
+ )
201
+
202
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
203
+ self.out_proj = dense()
204
+
205
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
206
+
207
+ if self.causal:
208
+ self.causal_mask = make_causal_mask(
209
+ jnp.ones((1, self.config.max_source_positions), dtype="bool"), dtype="bool"
210
+ )
211
+
212
+ def _split_heads(self, hidden_states):
213
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
214
+
215
+ def _merge_heads(self, hidden_states):
216
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
217
+
218
+ @nn.compact
219
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
220
+ """
221
+ This function takes projected key, value states from a single input token and concatenates the states to cached
222
+ states from previous steps. This function is slightly adapted from the official Flax repository:
223
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
224
+ """
225
+ # detect if we're initializing by absence of existing cache data.
226
+ is_initialized = self.has_variable("cache", "cached_key")
227
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
228
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
229
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
230
+
231
+ if is_initialized:
232
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
233
+ # update key, value caches with our new 1d spatial slices
234
+ cur_index = cache_index.value
235
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
236
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
237
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
238
+ cached_key.value = key
239
+ cached_value.value = value
240
+ num_updated_cache_vectors = query.shape[1]
241
+ cache_index.value = cache_index.value + num_updated_cache_vectors
242
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
243
+ pad_mask = jnp.broadcast_to(
244
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
245
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
246
+ )
247
+ attention_mask = combine_masks(pad_mask, attention_mask)
248
+ return key, value, attention_mask
249
+
250
+ def __call__(
251
+ self,
252
+ hidden_states: jnp.ndarray,
253
+ key_value_states: Optional[jnp.ndarray] = None,
254
+ attention_mask: Optional[jnp.ndarray] = None,
255
+ init_cache: bool = False,
256
+ deterministic: bool = True,
257
+ ) -> Tuple[jnp.ndarray]:
258
+ """Input shape: Batch x Time x Channel"""
259
+
260
+ # if key_value_states are provided this layer is used as a cross-attention layer
261
+ # for the decoder
262
+ is_cross_attention = key_value_states is not None
263
+ batch_size = hidden_states.shape[0]
264
+
265
+ # get query proj
266
+ query_states = self.q_proj(hidden_states) # Scaling is handled internally by `dot_product_attention_weights`.
267
+ # get key, value proj
268
+ if is_cross_attention:
269
+ # cross_attentions
270
+ key_states = self.k_proj(key_value_states)
271
+ value_states = self.v_proj(key_value_states)
272
+ else:
273
+ # self_attention
274
+ key_states = self.k_proj(hidden_states)
275
+ value_states = self.v_proj(hidden_states)
276
+
277
+ query_states = self._split_heads(query_states)
278
+ key_states = self._split_heads(key_states)
279
+ value_states = self._split_heads(value_states)
280
+
281
+ # handle cache prepare causal attention mask
282
+ if self.causal:
283
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
284
+ if self.has_variable("cache", "cached_key"):
285
+ mask_shift = self.variables["cache"]["cache_index"]
286
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
287
+ causal_mask = lax.dynamic_slice(
288
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
289
+ )
290
+ else:
291
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
292
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
293
+
294
+ # combine masks if needed
295
+ if attention_mask is not None and self.causal:
296
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
297
+ attention_mask = combine_masks(attention_mask, causal_mask)
298
+ elif self.causal:
299
+ attention_mask = causal_mask
300
+ elif attention_mask is not None:
301
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
302
+
303
+ # During fast autoregressive decoding, we feed one position at a time,
304
+ # and cache the keys and values step by step.
305
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
306
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
307
+ key_states, value_states, query_states, attention_mask
308
+ )
309
+
310
+ # Convert the boolean attention mask to an attention bias.
311
+ if attention_mask is not None:
312
+ # attention mask in the form of attention bias
313
+ attention_bias = lax.select(
314
+ attention_mask > 0,
315
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
316
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
317
+ )
318
+ else:
319
+ attention_bias = None
320
+
321
+ dropout_rng = None
322
+ if not deterministic and self.dropout > 0.0:
323
+ dropout_rng = self.make_rng("dropout")
324
+
325
+ attn_weights = dot_product_attention_weights(
326
+ query_states,
327
+ key_states,
328
+ bias=attention_bias,
329
+ dropout_rng=dropout_rng,
330
+ dropout_rate=self.dropout,
331
+ broadcast_dropout=True,
332
+ deterministic=deterministic,
333
+ dtype=self.dtype,
334
+ precision="high",
335
+ )
336
+
337
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
338
+ attn_output = self._merge_heads(attn_output)
339
+ attn_output = self.out_proj(attn_output)
340
+
341
+ return attn_output, attn_weights
342
+
343
+
344
+ class FlaxIndicTransEncoderLayer(nn.Module):
345
+ config: IndicTransConfig
346
+ dtype: jnp.dtype = jnp.float32
347
+
348
+ def setup(self) -> None:
349
+ self.embed_dim = self.config.encoder_embed_dim
350
+ self.self_attn = FlaxIndicTransAttention(
351
+ config=self.config,
352
+ embed_dim=self.embed_dim,
353
+ num_heads=self.config.encoder_attention_heads,
354
+ dropout=self.config.attention_dropout,
355
+ dtype=self.dtype,
356
+ )
357
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
358
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
359
+ self.activation_fn = ACT2FN[self.config.activation_function]
360
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
361
+ self.fc1 = nn.Dense(
362
+ self.config.encoder_ffn_dim,
363
+ dtype=self.dtype,
364
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
365
+ )
366
+ self.fc2 = nn.Dense(
367
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
368
+ )
369
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
370
+ self.normalize_before = self.config.encoder_normalize_before
371
+
372
+ def __call__(
373
+ self,
374
+ hidden_states: jnp.ndarray,
375
+ attention_mask: jnp.ndarray,
376
+ output_attentions: bool = True,
377
+ deterministic: bool = True,
378
+ ) -> Tuple[jnp.ndarray]:
379
+ residual = hidden_states
380
+ if self.normalize_before:
381
+ hidden_states = self.self_attn_layer_norm(hidden_states)
382
+ hidden_states, attn_weights = self.self_attn(
383
+ hidden_states=hidden_states,
384
+ attention_mask=attention_mask
385
+
386
+ )
387
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
388
+ hidden_states = residual + hidden_states
389
+ if not self.normalize_before:
390
+ hidden_states = self.self_attn_layer_norm(hidden_states)
391
+
392
+ residual = hidden_states
393
+ if self.normalize_before:
394
+ hidden_states = self.final_layer_norm(hidden_states)
395
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
396
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
397
+ hidden_states = self.fc2(hidden_states)
398
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
399
+ hidden_states = residual + hidden_states
400
+ if not self.normalize_before:
401
+ hidden_states = self.final_layer_norm(hidden_states)
402
+
403
+ outputs = (hidden_states,)
404
+
405
+ if output_attentions:
406
+ outputs += (attn_weights,)
407
+
408
+ return outputs
409
+
410
+
411
+ class FlaxIndicTransEncoderLayerCollection(nn.Module):
412
+ config: IndicTransConfig
413
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
414
+
415
+ def setup(self):
416
+ self.layers = [
417
+ FlaxIndicTransEncoderLayer(self.config, name=str(i), dtype=self.dtype)
418
+ for i in range(self.config.encoder_layers)
419
+ ]
420
+ self.layerdrop = self.config.encoder_layerdrop
421
+
422
+ def __call__(
423
+ self,
424
+ hidden_states,
425
+ attention_mask,
426
+ deterministic: bool = True,
427
+ output_attentions: bool = False,
428
+ output_hidden_states: bool = False,
429
+ return_dict: bool = True,
430
+ ):
431
+ all_attentions = () if output_attentions else None
432
+ all_hidden_states = () if output_hidden_states else None
433
+
434
+ for encoder_layer in self.layers:
435
+ if output_hidden_states:
436
+ all_hidden_states = all_hidden_states + (hidden_states,)
437
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
438
+ dropout_probability = jax.random.normal(jax.random.PRNGKey(0), [])
439
+ if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
440
+ layer_outputs = (None, None)
441
+ else:
442
+ layer_outputs = encoder_layer(
443
+ hidden_states,
444
+ attention_mask,
445
+ output_attentions,
446
+ deterministic,
447
+ )
448
+ hidden_states = layer_outputs[0]
449
+ if output_attentions:
450
+ all_attentions = all_attentions + (layer_outputs[1],)
451
+
452
+ if output_hidden_states:
453
+ all_hidden_states += (hidden_states,)
454
+
455
+ outputs = (hidden_states, all_hidden_states, all_attentions)
456
+
457
+ if not return_dict:
458
+ return tuple(v for v in outputs if v is not None)
459
+
460
+ return FlaxBaseModelOutput(
461
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
462
+ )
463
+
464
+
465
+ class FlaxIndicTransDecoderLayer(nn.Module):
466
+ config: IndicTransConfig
467
+ dtype: jnp.dtype = jnp.float32
468
+
469
+ def setup(self) -> None:
470
+ self.embed_dim = self.config.decoder_embed_dim
471
+ self.self_attn = FlaxIndicTransAttention(
472
+ config=self.config,
473
+ embed_dim=self.embed_dim,
474
+ num_heads=self.config.decoder_attention_heads,
475
+ dropout=self.config.attention_dropout,
476
+ causal=True,
477
+ dtype=self.dtype,
478
+ )
479
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
480
+ self.activation_fn = ACT2FN[self.config.activation_function]
481
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
482
+
483
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
484
+ self.encoder_attn = FlaxIndicTransAttention(
485
+ config=self.config,
486
+ embed_dim=self.embed_dim,
487
+ num_heads=self.config.decoder_attention_heads,
488
+ dropout=self.config.attention_dropout,
489
+ causal=False,
490
+ dtype=self.dtype,
491
+ )
492
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
493
+ self.fc1 = nn.Dense(
494
+ self.config.decoder_ffn_dim,
495
+ dtype=self.dtype,
496
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
497
+ )
498
+ self.fc2 = nn.Dense(
499
+ self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
500
+ )
501
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
502
+ self.normalize_before = self.config.decoder_normalize_before
503
+
504
+ def __call__(
505
+ self,
506
+ hidden_states: jnp.ndarray,
507
+ attention_mask: jnp.ndarray,
508
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
509
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
510
+ init_cache: bool = False,
511
+ output_attentions: bool = True,
512
+ deterministic: bool = True,
513
+ ) -> Tuple[jnp.ndarray]:
514
+ residual = hidden_states
515
+ if self.normalize_before:
516
+ hidden_states = self.self_attn_layer_norm(hidden_states)
517
+
518
+ # Self Attention
519
+ hidden_states, self_attn_weights = self.self_attn(
520
+ hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
521
+ )
522
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
523
+ hidden_states = residual + hidden_states
524
+ if not self.normalize_before:
525
+ hidden_states = self.self_attn_layer_norm(hidden_states)
526
+
527
+ # Cross-Attention Block
528
+ cross_attn_weights = None
529
+ if encoder_hidden_states is not None:
530
+ residual = hidden_states
531
+ if self.normalize_before:
532
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
533
+
534
+ hidden_states, cross_attn_weights = self.encoder_attn(
535
+ hidden_states=hidden_states,
536
+ key_value_states=encoder_hidden_states,
537
+ attention_mask=encoder_attention_mask,
538
+ # init_cache=init_cache
539
+ )
540
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
541
+ hidden_states = residual + hidden_states
542
+ if not self.normalize_before:
543
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
544
+
545
+ # Fully Connected
546
+ residual = hidden_states
547
+ if self.normalize_before:
548
+ hidden_states = self.final_layer_norm(hidden_states)
549
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
550
+ hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
551
+ hidden_states = self.fc2(hidden_states)
552
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
553
+ hidden_states = residual + hidden_states
554
+ if not self.normalize_before:
555
+ hidden_states = self.final_layer_norm(hidden_states)
556
+
557
+ outputs = (hidden_states,)
558
+
559
+ if output_attentions:
560
+ outputs += (self_attn_weights, cross_attn_weights)
561
+
562
+ return outputs
563
+
564
+
565
+ class FlaxIndicTransDecoderLayerCollection(nn.Module):
566
+ config: IndicTransConfig
567
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
568
+
569
+ def setup(self):
570
+ self.layers = [
571
+ FlaxIndicTransDecoderLayer(self.config, name=str(i), dtype=self.dtype)
572
+ for i in range(self.config.decoder_layers)
573
+ ]
574
+ self.layerdrop = self.config.decoder_layerdrop
575
+
576
+ def __call__(
577
+ self,
578
+ hidden_states,
579
+ attention_mask,
580
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
581
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
582
+ deterministic: bool = True,
583
+ init_cache: bool = False,
584
+ output_attentions: bool = False,
585
+ output_hidden_states: bool = False,
586
+ return_dict: bool = True,
587
+ ):
588
+ # decoder layers
589
+ all_hidden_states = () if output_hidden_states else None
590
+ all_self_attns = () if output_attentions else None
591
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
592
+
593
+ for decoder_layer in self.layers:
594
+ if output_hidden_states:
595
+ all_hidden_states += (hidden_states,)
596
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
597
+ dropout_probability = jax.random.normal(jax.random.PRNGKey(0), [])
598
+ if not deterministic and (dropout_probability < self.layerdrop):
599
+ layer_outputs = (None, None, None)
600
+ else:
601
+ layer_outputs = decoder_layer(
602
+ hidden_states,
603
+ attention_mask=attention_mask,
604
+ encoder_hidden_states=encoder_hidden_states,
605
+ encoder_attention_mask=encoder_attention_mask,
606
+ init_cache=init_cache,
607
+ output_attentions=output_attentions,
608
+ deterministic=deterministic,
609
+ )
610
+
611
+ hidden_states = layer_outputs[0]
612
+ if output_attentions:
613
+ all_self_attns += (layer_outputs[1],)
614
+
615
+ if encoder_hidden_states is not None:
616
+ all_cross_attentions += (layer_outputs[2],)
617
+
618
+ # add hidden states from the last decoder layer
619
+ if output_hidden_states:
620
+ all_hidden_states += (hidden_states,)
621
+
622
+ outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
623
+
624
+ if not return_dict:
625
+ return tuple(v for v in outputs if v is not None)
626
+
627
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
628
+ last_hidden_state=hidden_states,
629
+ hidden_states=all_hidden_states,
630
+ attentions=all_self_attns,
631
+ cross_attentions=all_cross_attentions,
632
+ )
633
+
634
+ class FlaxIndicTransEncoder(nn.Module):
635
+ config: IndicTransConfig
636
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
637
+
638
+ def setup(self):
639
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
640
+
641
+ embed_dim = self.config.encoder_embed_dim
642
+ self.padding_idx = self.config.pad_token_id
643
+ self.max_source_positions = self.config.max_source_positions
644
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
645
+
646
+ self.embed_tokens = nn.Embed(
647
+ self.config.encoder_vocab_size,
648
+ embed_dim,
649
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
650
+ )
651
+
652
+ self.embed_positions = FlaxIndicTransSinusoidalPositionalEmbedding(
653
+ self.config.max_source_positions,
654
+ embed_dim,
655
+ self.padding_idx,
656
+ )
657
+ self.layers = FlaxIndicTransEncoderLayerCollection(self.config, self.dtype)
658
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.encoder_normalize_before else None
659
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.layernorm_embedding else None
660
+
661
+ def __call__(
662
+ self,
663
+ input_ids,
664
+ attention_mask,
665
+ position_ids,
666
+ output_attentions: bool = False,
667
+ output_hidden_states: bool = False,
668
+ return_dict: bool = True,
669
+ deterministic: bool = True,
670
+ ):
671
+ input_shape = input_ids.shape
672
+ input_ids = input_ids.reshape(-1, input_shape[-1])
673
+
674
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
675
+
676
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
677
+
678
+ hidden_states = inputs_embeds + embed_pos
679
+ if self.layernorm_embedding is not None:
680
+ hidden_states = self.layernorm_embedding(hidden_states)
681
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
682
+
683
+ outputs = self.layers(
684
+ hidden_states,
685
+ attention_mask,
686
+ deterministic=deterministic,
687
+ output_attentions=output_attentions,
688
+ output_hidden_states=output_hidden_states,
689
+ return_dict=return_dict,
690
+ )
691
+
692
+ last_hidden_states = outputs[0]
693
+
694
+ if self.layer_norm is not None:
695
+ last_hidden_states = self.layer_norm(last_hidden_states)
696
+
697
+ # update the last element in `hidden_states` after applying `layernorm` above
698
+ hidden_states = None
699
+ if output_hidden_states:
700
+ hidden_states = outputs[1]
701
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
702
+
703
+ if not return_dict:
704
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
705
+ return tuple(v for v in outputs if v is not None)
706
+
707
+ return FlaxBaseModelOutput(
708
+ last_hidden_state=last_hidden_states,
709
+ hidden_states=hidden_states,
710
+ attentions=outputs.attentions,
711
+ )
712
+
713
+
714
+ class FlaxIndicTransDecoder(nn.Module):
715
+ config: IndicTransConfig
716
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
717
+
718
+ def setup(self):
719
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
720
+
721
+ embed_dim = self.config.encoder_embed_dim
722
+ self.padding_idx = self.config.pad_token_id
723
+ self.max_target_positions = self.config.max_target_positions
724
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
725
+
726
+ self.embed_tokens = nn.Embed(
727
+ self.config.decoder_vocab_size,
728
+ embed_dim,
729
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
730
+ )
731
+
732
+ self.embed_positions = FlaxIndicTransSinusoidalPositionalEmbedding(
733
+ self.config.max_target_positions,
734
+ embed_dim,
735
+ self.padding_idx,
736
+ )
737
+
738
+ self.layers = FlaxIndicTransDecoderLayerCollection(self.config, self.dtype)
739
+ self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.decoder_normalize_before else None
740
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) if self.config.layernorm_embedding else None
741
+
742
+
743
+ def __call__(
744
+ self,
745
+ input_ids,
746
+ attention_mask,
747
+ position_ids,
748
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
749
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
750
+ init_cache: bool = False,
751
+ output_attentions: bool = False,
752
+ output_hidden_states: bool = False,
753
+ return_dict: bool = True,
754
+ deterministic: bool = True,
755
+ ):
756
+
757
+ input_shape = input_ids.shape
758
+ input_ids = input_ids.reshape(-1, input_shape[-1])
759
+
760
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
761
+
762
+ # embed positions
763
+ positions = self.embed_positions(input_ids, inputs_embeds)
764
+
765
+ hidden_states = inputs_embeds + positions
766
+
767
+ if self.layernorm_embedding is not None:
768
+ hidden_states = self.layernorm_embedding(hidden_states)
769
+
770
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
771
+
772
+ outputs = self.layers(
773
+ hidden_states,
774
+ attention_mask,
775
+ encoder_hidden_states,
776
+ encoder_attention_mask,
777
+ deterministic=deterministic,
778
+ init_cache=init_cache,
779
+ output_attentions=output_attentions,
780
+ output_hidden_states=output_hidden_states,
781
+ return_dict=return_dict,
782
+ )
783
+
784
+ last_hidden_states = outputs[0]
785
+
786
+ if self.layer_norm is not None:
787
+ last_hidden_states = self.layer_norm(last_hidden_states)
788
+
789
+ # update the last element in `hidden_states` after applying `layernorm` above
790
+ hidden_states = None
791
+ if output_hidden_states:
792
+ hidden_states = outputs[1]
793
+ hidden_states = hidden_states[:-1] + (last_hidden_states,)
794
+
795
+ if not return_dict:
796
+ outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
797
+ return tuple(v for v in outputs if v is not None)
798
+
799
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
800
+ last_hidden_state=last_hidden_states,
801
+ hidden_states=hidden_states,
802
+ attentions=outputs.attentions,
803
+ cross_attentions=outputs.cross_attentions,
804
+ )
805
+
806
+
807
+ class FlaxIndicTransModule(nn.Module):
808
+ config: IndicTransConfig
809
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
810
+
811
+ def setup(self):
812
+ self.encoder = FlaxIndicTransEncoder(self.config, dtype=self.dtype)
813
+ self.decoder = FlaxIndicTransDecoder(self.config, dtype=self.dtype)
814
+
815
+ def _get_encoder_module(self):
816
+ return self.encoder
817
+
818
+ def _get_decoder_module(self):
819
+ return self.decoder
820
+
821
+ def __call__(
822
+ self,
823
+ input_ids,
824
+ attention_mask,
825
+ decoder_input_ids,
826
+ decoder_attention_mask,
827
+ position_ids,
828
+ decoder_position_ids,
829
+ output_attentions: bool = False,
830
+ output_hidden_states: bool = False,
831
+ return_dict: bool = True,
832
+ deterministic: bool = True,
833
+ ):
834
+ encoder_outputs = self.encoder(
835
+ input_ids=input_ids,
836
+ attention_mask=attention_mask,
837
+ position_ids=position_ids,
838
+ output_attentions=output_attentions,
839
+ output_hidden_states=output_hidden_states,
840
+ return_dict=return_dict,
841
+ deterministic=deterministic,
842
+ )
843
+
844
+ decoder_outputs = self.decoder(
845
+ input_ids=decoder_input_ids,
846
+ attention_mask=decoder_attention_mask,
847
+ position_ids=decoder_position_ids,
848
+ encoder_hidden_states=encoder_outputs[0],
849
+ encoder_attention_mask=attention_mask,
850
+ output_attentions=output_attentions,
851
+ output_hidden_states=output_hidden_states,
852
+ return_dict=return_dict,
853
+ deterministic=deterministic,
854
+ )
855
+
856
+ if not return_dict:
857
+ return decoder_outputs + encoder_outputs
858
+
859
+ return FlaxSeq2SeqModelOutput(
860
+ last_hidden_state=decoder_outputs.last_hidden_state,
861
+ decoder_hidden_states=decoder_outputs.hidden_states,
862
+ decoder_attentions=decoder_outputs.attentions,
863
+ cross_attentions=decoder_outputs.cross_attentions,
864
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
865
+ encoder_hidden_states=encoder_outputs.hidden_states,
866
+ encoder_attentions=encoder_outputs.attentions,
867
+ )
868
+
869
+
870
+ class FlaxIndicTransPreTrainedModel(FlaxPreTrainedModel):
871
+ config_class = IndicTransConfig
872
+ base_model_prefix: str = "model"
873
+ module_class: nn.Module = None
874
+
875
+ def __init__(
876
+ self,
877
+ config: IndicTransConfig,
878
+ input_shape: Tuple[int] = (1, 1),
879
+ seed: int = 0,
880
+ dtype: jnp.dtype = jnp.float32,
881
+ _do_init: bool = True,
882
+ **kwargs,
883
+ ):
884
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
885
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
886
+
887
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
888
+ # init input tensors
889
+ input_ids = jnp.zeros(input_shape, dtype="i4")
890
+ # place an EOS token in the last position so the initialization pass also covers heads that index the final token (carried over from the Flax MBart implementation this port is adapted from)
891
+ input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
892
+ attention_mask = jnp.ones_like(input_ids)
893
+ decoder_input_ids = input_ids
894
+ decoder_attention_mask = jnp.ones_like(input_ids)
895
+
896
+ batch_size, sequence_length = input_ids.shape
897
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
898
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
899
+
900
+ params_rng, dropout_rng = jax.random.split(rng)
901
+ rngs = {"params": params_rng, "dropout": dropout_rng}
902
+
903
+ random_params = self.module.init(
904
+ rngs,
905
+ input_ids,
906
+ attention_mask,
907
+ decoder_input_ids,
908
+ decoder_attention_mask,
909
+ position_ids,
910
+ decoder_position_ids,
911
+ )["params"]
912
+
913
+ if params is not None:
914
+ random_params = flatten_dict(unfreeze(random_params))
915
+ params = flatten_dict(unfreeze(params))
916
+ for missing_key in self._missing_keys:
917
+ params[missing_key] = random_params[missing_key]
918
+ self._missing_keys = set()
919
+ return freeze(unflatten_dict(params))
920
+ else:
921
+ return random_params
922
+
923
+ def init_cache(self, batch_size, max_length, encoder_outputs):
924
+ # init input variables to retrieve cache
925
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
926
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
927
+ decoder_position_ids = jnp.broadcast_to(
928
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
929
+ )
930
+
931
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
932
+ decoder_module = module._get_decoder_module()
933
+ return decoder_module(
934
+ decoder_input_ids,
935
+ decoder_attention_mask,
936
+ decoder_position_ids,
937
+ **kwargs,
938
+ )
939
+
940
+ init_variables = self.module.init(
941
+ jax.random.PRNGKey(0),
942
+ decoder_input_ids=decoder_input_ids,
943
+ decoder_attention_mask=decoder_attention_mask,
944
+ decoder_position_ids=decoder_position_ids,
945
+ encoder_hidden_states=encoder_outputs[0],
946
+ init_cache=True,
947
+ method=_decoder_forward, # we only need to call the decoder to init the cache
948
+ )
949
+ return unfreeze(init_variables["cache"])
950
+
951
+ def encode(
952
+ self,
953
+ input_ids: jnp.ndarray,
954
+ attention_mask: Optional[jnp.ndarray] = None,
955
+ position_ids: Optional[jnp.ndarray] = None,
956
+ output_attentions: Optional[bool] = None,
957
+ output_hidden_states: Optional[bool] = None,
958
+ return_dict: Optional[bool] = None,
959
+ train: bool = False,
960
+ params: dict = None,
961
+ dropout_rng: PRNGKey = None,
962
+ ):
963
+
964
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
965
+ output_hidden_states = (
966
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
967
+ )
968
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
969
+
970
+ if attention_mask is None:
971
+ attention_mask = jnp.ones_like(input_ids)
972
+ if position_ids is None:
973
+ batch_size, sequence_length = input_ids.shape
974
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
975
+
976
+ # Handle any PRNG if needed
977
+ rngs = {}
978
+ if dropout_rng is not None:
979
+ rngs["dropout"] = dropout_rng
980
+
981
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
982
+ encode_module = module._get_encoder_module()
983
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
984
+
985
+ return self.module.apply(
986
+ {"params": params or self.params},
987
+ input_ids=jnp.array(input_ids, dtype="i4"),
988
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
989
+ position_ids=jnp.array(position_ids, dtype="i4"),
990
+ output_attentions=output_attentions,
991
+ output_hidden_states=output_hidden_states,
992
+ return_dict=return_dict,
993
+ deterministic=not train,
994
+ rngs=rngs,
995
+ method=_encoder_forward,
996
+ )
997
+
998
+ def decode(
999
+ self,
1000
+ decoder_input_ids,
1001
+ encoder_outputs,
1002
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1003
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1004
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1005
+ past_key_values: dict = None,
1006
+ output_attentions: Optional[bool] = None,
1007
+ output_hidden_states: Optional[bool] = None,
1008
+ return_dict: Optional[bool] = None,
1009
+ train: bool = False,
1010
+ params: dict = None,
1011
+ dropout_rng: PRNGKey = None,
1012
+ ):
1013
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1014
+ output_hidden_states = (
1015
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1016
+ )
1017
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1018
+
1019
+ encoder_hidden_states = encoder_outputs[0]
1020
+ if encoder_attention_mask is None:
1021
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
1022
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
1023
+
1024
+ batch_size, sequence_length = decoder_input_ids.shape
1025
+ if decoder_attention_mask is None:
1026
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1027
+
1028
+ if decoder_position_ids is None:
1029
+ if past_key_values is not None:
1030
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1031
+
1032
+ decoder_position_ids = jnp.broadcast_to(
1033
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1034
+ )
1035
+
1036
+ # Handle any PRNG if needed
1037
+ rngs = {}
1038
+ if dropout_rng is not None:
1039
+ rngs["dropout"] = dropout_rng
1040
+
1041
+ inputs = {"params": params or self.params}
1042
+
1043
+ # If past_key_values are passed, the cache is already initialized, so the private flag init_cache has to be
1044
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1045
+ # it can be updated by the FlaxIndicTransAttention module.
1046
+ if past_key_values:
1047
+ inputs["cache"] = past_key_values
1048
+ mutable = ["cache"]
1049
+ else:
1050
+ mutable = False
1051
+
1052
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
1053
+ decoder_module = module._get_decoder_module()
1054
+ return decoder_module(
1055
+ decoder_input_ids,
1056
+ decoder_attention_mask,
1057
+ decoder_position_ids,
1058
+ **kwargs,
1059
+ )
1060
+
1061
+ outputs = self.module.apply(
1062
+ inputs,
1063
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1064
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1065
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1066
+ encoder_hidden_states=encoder_hidden_states,
1067
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
1068
+ output_attentions=output_attentions,
1069
+ output_hidden_states=output_hidden_states,
1070
+ return_dict=return_dict,
1071
+ deterministic=not train,
1072
+ rngs=rngs,
1073
+ mutable=mutable,
1074
+ method=_decoder_forward,
1075
+ )
1076
+
1077
+ # add updated cache to model output
1078
+ if past_key_values is not None and return_dict:
1079
+ outputs, past = outputs
1080
+ outputs["past_key_values"] = unfreeze(past["cache"])
1081
+ return outputs
1082
+ elif past_key_values is not None and not return_dict:
1083
+ outputs, past = outputs
1084
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1085
+
1086
+ return outputs
1087
+
1088
+ def __call__(
1089
+ self,
1090
+ input_ids: jnp.ndarray,
1091
+ attention_mask: Optional[jnp.ndarray] = None,
1092
+ decoder_input_ids: Optional[jnp.ndarray] = None,
1093
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1094
+ position_ids: Optional[jnp.ndarray] = None,
1095
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1096
+ output_attentions: Optional[bool] = None,
1097
+ output_hidden_states: Optional[bool] = None,
1098
+ return_dict: Optional[bool] = None,
1099
+ train: bool = False,
1100
+ params: dict = None,
1101
+ dropout_rng: PRNGKey = None,
1102
+ ):
1103
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1104
+ output_hidden_states = (
1105
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1106
+ )
1107
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1108
+
1109
+ # prepare encoder inputs
1110
+ if attention_mask is None:
1111
+ attention_mask = jnp.ones_like(input_ids)
1112
+ if position_ids is None:
1113
+ batch_size, sequence_length = input_ids.shape
1114
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
1115
+
1116
+ # prepare decoder inputs
1117
+ if decoder_input_ids is None:
1118
+ decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id, self.config.decoder_start_token_id)
1119
+ if decoder_attention_mask is None:
1120
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
1121
+ if decoder_position_ids is None:
1122
+ batch_size, sequence_length = decoder_input_ids.shape
1123
+ decoder_position_ids = jnp.broadcast_to(
1124
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1125
+ )
1126
+
1127
+ # Handle any PRNG if needed
1128
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
1129
+
1130
+ return self.module.apply(
1131
+ {"params": params or self.params},
1132
+ input_ids=jnp.array(input_ids, dtype="i4"),
1133
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
1134
+ position_ids=jnp.array(position_ids, dtype="i4"),
1135
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1136
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1137
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1138
+ output_attentions=output_attentions,
1139
+ output_hidden_states=output_hidden_states,
1140
+ return_dict=return_dict,
1141
+ deterministic=not train,
1142
+ rngs=rngs,
1143
+ )
1144
+
1145
+
1146
+ class FlaxIndicTransModel(FlaxIndicTransPreTrainedModel):
1147
+ config: IndicTransConfig
1148
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1149
+ module_class = FlaxIndicTransModule
1150
+
1151
+
1152
+ class FlaxIndicTransForConditionalGenerationModule(nn.Module):
1153
+ config: IndicTransConfig
1154
+ dtype: jnp.dtype = jnp.float32
1155
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
1156
+
1157
+ def setup(self):
1158
+ self.model = FlaxIndicTransModule(config=self.config, dtype=self.dtype)
1159
+
1160
+ self.lm_head = nn.Dense(
1161
+ self.config.decoder_vocab_size,
1162
+ use_bias=False,
1163
+ dtype=self.dtype,
1164
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
1165
+ )
1166
+
1167
+ def _get_encoder_module(self):
1168
+ return self.model.encoder
1169
+
1170
+ def _get_decoder_module(self):
1171
+ return self.model.decoder
1172
+
1173
+ def __call__(
1174
+ self,
1175
+ input_ids,
1176
+ attention_mask,
1177
+ decoder_input_ids,
1178
+ decoder_attention_mask,
1179
+ position_ids,
1180
+ decoder_position_ids,
1181
+ output_attentions: bool = False,
1182
+ output_hidden_states: bool = False,
1183
+ return_dict: bool = True,
1184
+ deterministic: bool = True,
1185
+ ):
1186
+ outputs = self.model(
1187
+ input_ids=input_ids,
1188
+ attention_mask=attention_mask,
1189
+ decoder_input_ids=decoder_input_ids,
1190
+ decoder_attention_mask=decoder_attention_mask,
1191
+ position_ids=position_ids,
1192
+ decoder_position_ids=decoder_position_ids,
1193
+ output_attentions=output_attentions,
1194
+ output_hidden_states=output_hidden_states,
1195
+ return_dict=return_dict,
1196
+ deterministic=deterministic,
1197
+ )
1198
+
1199
+ hidden_states = outputs[0]
1200
+
1201
+ if self.config.share_decoder_input_output_embed:
1202
+ shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
1203
+ lm_logits = jax.lax.stop_gradient(self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states))
1204
+ else:
1205
+ lm_logits = jax.lax.stop_gradient(self.lm_head(hidden_states))
1206
+
1207
+ if not return_dict:
1208
+ output = (lm_logits,) + outputs[1:]
1209
+ return output
1210
+
1211
+ return FlaxSeq2SeqLMOutput(
1212
+ logits=lm_logits,
1213
+ decoder_hidden_states=outputs.decoder_hidden_states,
1214
+ decoder_attentions=outputs.decoder_attentions,
1215
+ cross_attentions=outputs.cross_attentions,
1216
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1217
+ encoder_hidden_states=outputs.encoder_hidden_states,
1218
+ encoder_attentions=outputs.encoder_attentions,
1219
+ )
1220
+
1221
+
1222
+ class FlaxIndicTransForConditionalGeneration(FlaxIndicTransPreTrainedModel):
1223
+ module_class = FlaxIndicTransForConditionalGenerationModule
1224
+ dtype: jnp.dtype = jnp.float32
1225
+
1226
+ def decode(
1227
+ self,
1228
+ decoder_input_ids,
1229
+ encoder_outputs,
1230
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1231
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1232
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1233
+ past_key_values: dict = None,
1234
+ output_attentions: Optional[bool] = None,
1235
+ output_hidden_states: Optional[bool] = None,
1236
+ return_dict: Optional[bool] = None,
1237
+ train: bool = False,
1238
+ params: dict = None,
1239
+ dropout_rng: PRNGKey = None,
1240
+ ):
1241
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1242
+ output_hidden_states = (
1243
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1244
+ )
1245
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1246
+
1247
+ encoder_hidden_states = encoder_outputs[0]
1248
+ if encoder_attention_mask is None:
1249
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
1250
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
1251
+
1252
+ batch_size, sequence_length = decoder_input_ids.shape
1253
+ if decoder_attention_mask is None:
1254
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1255
+
1256
+ if decoder_position_ids is None:
1257
+ if past_key_values is not None:
1258
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
1259
+
1260
+ decoder_position_ids = jnp.broadcast_to(
1261
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1262
+ )
1263
+
1264
+ # Handle any PRNG if needed
1265
+ rngs = {}
1266
+ if dropout_rng is not None:
1267
+ rngs["dropout"] = dropout_rng
1268
+
1269
+ inputs = {"params": params or self.params}
1270
+
1271
+ # If past_key_values are passed, the cache is already initialized, so the private flag init_cache has to be
1272
+ # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1273
+ # it can be updated by the FlaxIndicTransAttention module.
1274
+ if past_key_values:
1275
+ inputs["cache"] = past_key_values
1276
+ mutable = ["cache"]
1277
+ else:
1278
+ mutable = False
1279
+
1280
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
1281
+ decoder_module = module._get_decoder_module()
1282
+ outputs = decoder_module(
1283
+ decoder_input_ids,
1284
+ decoder_attention_mask,
1285
+ decoder_position_ids,
1286
+ **kwargs,
1287
+ )
1288
+ hidden_states = outputs[0]
1289
+
1290
+ if self.config.share_decoder_input_output_embed:
1291
+ shared_embedding = module.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
1292
+ lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
1293
+ else:
1294
+ lm_logits = module.lm_head(hidden_states)
1295
+
1296
+ return lm_logits, outputs
1297
+
1298
+ outputs = self.module.apply(
1299
+ inputs,
1300
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1301
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1302
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1303
+ encoder_hidden_states=encoder_hidden_states,
1304
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
1305
+ output_attentions=output_attentions,
1306
+ output_hidden_states=output_hidden_states,
1307
+ return_dict=return_dict,
1308
+ deterministic=not train,
1309
+ rngs=rngs,
1310
+ mutable=mutable,
1311
+ method=_decoder_forward,
1312
+ )
1313
+
1314
+ if past_key_values is None:
1315
+ lm_logits, decoder_outputs = outputs
1316
+ else:
1317
+ (lm_logits, decoder_outputs), past = outputs
1318
+
1319
+ if return_dict:
1320
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1321
+ logits=lm_logits,
1322
+ hidden_states=decoder_outputs.hidden_states,
1323
+ attentions=decoder_outputs.attentions,
1324
+ cross_attentions=decoder_outputs.cross_attentions,
1325
+ )
1326
+ else:
1327
+ outputs = (lm_logits,) + decoder_outputs[1:]
1328
+
1329
+ # add updated cache to model output
1330
+ if past_key_values is not None and return_dict:
1331
+ outputs["past_key_values"] = unfreeze(past["cache"])
1332
+ return outputs
1333
+ elif past_key_values is not None and not return_dict:
1334
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1335
+
1336
+ return outputs
1337
+
1338
+ def prepare_inputs_for_generation(
1339
+ self,
1340
+ decoder_input_ids,
1341
+ max_length,
1342
+ attention_mask: Optional[jax.Array] = None,
1343
+ decoder_attention_mask: Optional[jax.Array] = None,
1344
+ encoder_outputs=None,
1345
+ **kwargs,
1346
+ ):
1347
+ # initializing the cache
1348
+ batch_size, seq_length = decoder_input_ids.shape
1349
+
1350
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
1351
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1352
+ # But since the decoder uses a causal mask, those positions are masked anyways.
1353
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1354
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1355
+ if decoder_attention_mask is not None:
1356
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
1357
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
1358
+ else:
1359
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1360
+
1361
+ return {
1362
+ "past_key_values": past_key_values,
1363
+ "encoder_outputs": encoder_outputs,
1364
+ "encoder_attention_mask": attention_mask,
1365
+ "decoder_attention_mask": extended_attention_mask,
1366
+ "decoder_position_ids": position_ids,
1367
+ }
1368
+
1369
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1370
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1371
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
1372
+ return model_kwargs
1373
+
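
For reference, a minimal sketch of how the encode/decode split defined above is typically exercised; generate() uses the same path via init_cache and prepare_inputs_for_generation. The repo id is a placeholder and the dummy inputs stand in for tokenized text:

import jax.numpy as jnp
from transformers import FlaxAutoModelForSeq2SeqLM

# Placeholder repo id; assumes the flax_model.msgpack added in this commit is available there.
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
)

# Dummy inputs standing in for tokenizer output.
input_ids = jnp.ones((1, 8), dtype="i4")
attention_mask = jnp.ones_like(input_ids)

# encode() runs FlaxIndicTransEncoder once; its output is reused for every decoding step.
encoder_outputs = model.encode(input_ids=input_ids, attention_mask=attention_mask)

# decode() runs the decoder plus lm_head and returns logits for the next token.
decoder_input_ids = jnp.array([[model.config.decoder_start_token_id]], dtype="i4")
decoder_outputs = model.decode(decoder_input_ids, encoder_outputs)
next_token_logits = decoder_outputs.logits[:, -1]
print(next_token_logits.shape)  # (1, decoder_vocab_size)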