File size: 21,672 Bytes
b00cdfe
 
a244e91
 
 
 
 
 
 
 
 
 
 
 
3ed2a5d
 
 
 
a244e91
3ed2a5d
a244e91
 
 
 
 
 
 
 
3ed2a5d
a244e91
 
 
 
 
b00cdfe
 
a244e91
 
 
 
 
 
 
 
3ed2a5d
 
 
 
 
 
 
 
 
 
a244e91
 
 
 
 
 
 
 
 
 
3ed2a5d
 
 
a244e91
3ed2a5d
a244e91
 
 
 
 
 
 
 
 
 
 
bfef308
 
a244e91
 
 
 
 
 
3ed2a5d
a244e91
3ed2a5d
 
 
 
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
 
 
 
3ed2a5d
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
 
b00cdfe
a244e91
 
 
 
3ed2a5d
 
a244e91
 
 
3ed2a5d
a244e91
3ed2a5d
 
 
 
 
b00cdfe
3ed2a5d
 
 
 
a244e91
 
 
 
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
3ed2a5d
 
 
 
 
a244e91
 
3ed2a5d
a244e91
 
3ed2a5d
 
 
a244e91
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b00cdfe
a244e91
b00cdfe
a244e91
b00cdfe
a244e91
3ed2a5d
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
89ca80d
a244e91
 
 
 
 
 
 
 
 
 
3ed2a5d
a244e91
 
3ed2a5d
 
a244e91
 
 
 
 
 
 
 
 
 
b00cdfe
a244e91
 
b00cdfe
a244e91
b00cdfe
a244e91
 
 
 
 
 
3ed2a5d
 
 
a244e91
3ed2a5d
a244e91
 
 
 
 
03d8c80
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ed2a5d
a244e91
 
3ed2a5d
 
 
a244e91
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
3ed2a5d
 
a244e91
3ed2a5d
a244e91
3ed2a5d
a244e91
3ed2a5d
a244e91
 
 
3ed2a5d
 
 
 
 
 
 
a244e91
 
 
 
 
 
 
 
 
4bf5c74
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
 
 
 
 
 
 
 
3ed2a5d
a244e91
 
3ed2a5d
 
a244e91
 
 
 
7d3b1a0
a244e91
 
 
 
3ed2a5d
 
 
 
 
 
 
 
 
 
7d3b1a0
3ed2a5d
 
a244e91
 
 
 
3ed2a5d
a244e91
 
3ed2a5d
a244e91
 
 
 
3ed2a5d
a244e91
 
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
 
 
 
 
3ed2a5d
 
 
a244e91
 
 
 
3ed2a5d
a244e91
 
 
b00cdfe
a244e91
b00cdfe
 
a244e91
 
 
 
b00cdfe
 
 
 
a244e91
 
b00cdfe
 
 
 
a244e91
 
b00cdfe
 
 
 
 
a244e91
b88266d
 
b00cdfe
 
 
 
 
 
a244e91
b00cdfe
 
a244e91
b00cdfe
 
 
a244e91
b00cdfe
 
 
a244e91
 
b8c22f0
b00cdfe
a244e91
b00cdfe
 
 
 
 
54ece9e
b8c22f0
 
b00cdfe
a244e91
b00cdfe
a244e91
b00cdfe
 
 
a244e91
 
 
 
b00cdfe
e30ab96
a244e91
 
 
 
b00cdfe
 
a244e91
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
import os
from typing import Callable, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from transformers.modeling_flax_outputs import (
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from .configuration_vit_gpt2 import ViTGPT2Config
from transformers import ViTConfig, GPT2Config
from transformers import FlaxPreTrainedModel, FlaxViTModel
from transformers.models.vit.modeling_flax_vit import FlaxViTModule
from .modeling_flax_gpt2 import (
    FlaxGPT2PreTrainedModel,
    FlaxGPT2Module,
    FlaxGPT2Model,
    FlaxGPT2LMHeadModule,
    FlaxGPT2LMHeadModel,
)


class FlaxViTGPT2LMModule(nn.Module):
    """Play the same role as ``FlaxBartModule`` but with the decoder equipped with a LM head.

    A ViT encoder (``FlaxViTModule``) encodes the pixels; its first output is fed to a GPT-2 decoder with an
    LM head (``FlaxGPT2LMHeadModule``) as ``encoder_hidden_states``.
    """
    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # NOTE: the attribute names ``encoder`` / ``decoder`` name the sub-trees of the parameter dict
        # (``from_vision_text_pretrained`` writes ``params["model"]["encoder"]`` / ``["decoder"]``), so keep them.
        vision_config = self.config.vision_config
        text_config = self.config.text_config
        self.encoder = FlaxViTModule(vision_config, dtype=self.dtype)
        self.decoder = FlaxGPT2LMHeadModule(text_config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the ViT encoder sub-module."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the GPT-2 LM-head decoder sub-module."""
        return self.decoder

    def __call__(
        self,
        pixel_values,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode ``pixel_values`` and decode ``decoder_input_ids`` while cross-attending to the encoding.

        ``attention_mask`` is not used by the encoder; it is handed to the decoder as ``encoder_attention_mask``.

        Returns:
            A ``FlaxSeq2SeqLMOutput`` when ``return_dict`` is true, otherwise the tuple
            ``decoder_outputs + encoder_outputs``.
        """
        # Flags passed identically to both the encoder and the decoder.
        shared_flags = dict(
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_outputs = self.encoder(pixel_values=pixel_values, **shared_flags)

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            **shared_flags,
        )

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=decoder_outputs.logits,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        # Tuple mode: decoder outputs come first, followed by the encoder outputs.
        return decoder_outputs + encoder_outputs


class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
    """Play the same role as ``FlaxBartForConditionalGenerationModule`` but with the decoder equipped with a LM head.

       Actually, it is identical to ``FlaxBartForConditionalGenerationModule`` with a different name: every call is
       delegated unchanged to the wrapped ``FlaxViTGPT2LMModule``.
    """
    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Kept under the name ``model``: the pretrained wrapper addresses these weights as ``params["model"]``.
        self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the ViT encoder of the wrapped model."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the GPT-2 LM-head decoder of the wrapped model."""
        return self.model.decoder

    def __call__(
        self,
        pixel_values,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Forward every argument to the wrapped ``FlaxViTGPT2LMModule`` and return its outputs as-is."""
        return self.model(
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )


class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
    """Play the same role as ``FlaxBartPretrainedModel``.

    Wraps ``module_class`` (set by subclasses) and provides weight initialization, decoder-cache initialization,
    and separate ``encode`` / ``decode`` entry points in addition to the full forward pass (``__call__``).
    """
    config_class = ViTGPT2Config
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTGPT2Config,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Instantiate ``module_class`` and let ``FlaxPreTrainedModel`` initialize its weights.

        Args:
            config: combined vision/text configuration.
            input_shape: ``(encoder_input_shape, decoder_input_shape)`` used for the initialization pass.
                Defaults to one ``(image_size, image_size, 3)`` image and a ``(1, 1)`` decoder input. The pixel
                shape is channels-last because ``init_weights`` feeds it to the module without transposing.
            seed: forwarded to ``FlaxPreTrainedModel.__init__``.
            dtype: dtype of the computation, given to both the module and the base class.
            **kwargs: extra keyword arguments for ``module_class``.
        """
        if input_shape is None:
            input_shape = (
                (1, config.vision_config.image_size, config.vision_config.image_size, 3),
                (1, 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # This will use ``self.init_weights``.
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize the module parameters by tracing one forward pass on dummy inputs.

        Args:
            rng: PRNG key. It draws the dummy pixel values and is split into the ``"params"`` and ``"dropout"``
                streams.
            input_shape: ``(encoder_input_shape, decoder_input_shape)``.

        Returns:
            The ``"params"`` collection produced by ``self.module.init``.
        """
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
        pixel_values = jax.random.normal(rng, encoder_input_shape)
        attention_mask = None
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        # Place an EOS token in the last position of the dummy decoder input (mirrors the Bart initialization).
        # ``jax.ops.index_update`` was deprecated and then removed from JAX; ``.at[...].set(...)`` is the
        # functional indexed update that replaces it.
        decoder_input_ids = decoder_input_ids.at[..., -1].set(self.config.text_config.eos_token_id)
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = decoder_input_ids.shape
        decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        """Create the decoder's ``"cache"`` collection for incremental decoding.

        Runs only the decoder once with ``init_cache=True`` on dummy ``(batch_size, max_length)`` token ids,
        cross-attending to ``encoder_outputs[0]``.

        Returns:
            The unfrozen ``"cache"`` collection.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape,
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the ViT encoder.

        ``pixel_values`` is transposed with axes ``(0, 2, 3, 1)`` before being fed to the module (same as
        ``FlaxViTPreTrainedModel.__call__``). ``attention_mask`` is accepted but not used. Flags left as ``None``
        fall back to ``self.config.vision_config``.
        """
        output_attentions = (output_attentions if output_attentions is not None else self.config.vision_config.output_attentions)
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.vision_config.return_dict

        # (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the GPT-2 decoder, cross-attending to ``encoder_outputs[0]``.

        Missing masks default to all-ones; missing position ids default to ``0..sequence_length-1``, which is
        only allowed when ``past_key_values`` is ``None``. When ``past_key_values`` is given, the updated cache
        is returned too: as ``outputs["past_key_values"]`` in dict mode, or as the second tuple element
        otherwise. Flags left as ``None`` fall back to ``self.config.text_config``.

        Raises:
            ValueError: if ``past_key_values`` is given without ``decoder_position_ids``.
        """

        output_attentions = (
            output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxGPT2Attention module
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Full encoder-decoder forward pass.

        ``pixel_values`` is transposed with axes ``(0, 2, 3, 1)`` before use. If ``decoder_input_ids`` is
        ``None``, a single ``config.decoder_start_token_id`` per example is used. Missing decoder masks default to
        all-ones and missing position ids to ``0..sequence_length-1``. Flags left as ``None`` fall back to the
        top-level ``self.config``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        if decoder_input_ids is None:
            decoder_input_ids = self.config.decoder_start_token_id * jnp.ones((pixel_values.shape[0], 1))
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            attention_mask=attention_mask,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )


class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
    """ViT encoder + GPT-2 LM-head decoder for conditional (image-to-text) generation.

    Adds the ``prepare_inputs_for_generation`` / ``update_inputs_for_generation`` hooks used for cached,
    step-by-step decoding, and a ``from_vision_text_pretrained`` constructor that assembles the model from two
    separately pretrained checkpoints.
    """
    module_class = FlaxViTGPT2LMForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the decoder; delegates unchanged to ``FlaxViTGPT2LMPreTrainedModel.decode``."""

        return super().decode(
            decoder_input_ids,
            encoder_outputs,
            encoder_attention_mask,
            decoder_attention_mask,
            decoder_position_ids,
            past_key_values,
            output_attentions,
            output_hidden_states,
            return_dict,
            train,
            params,
            dropout_rng,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        # ``jnp.DeviceArray`` was removed from ``jax.numpy``; ``jnp.ndarray`` is the portable array annotation
        # (annotations are evaluated at definition time, so a missing attribute breaks module import).
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the model kwargs for the first generation step.

        Initializes a decoder cache of length ``max_length`` and a static ``(batch_size, max_length)`` decoder
        attention mask. If ``decoder_attention_mask`` is given, it is copied into the front of that mask and the
        position ids are derived from its cumulative sum; otherwise positions are ``0..seq_length-1``.
        ``attention_mask`` is passed on as ``encoder_attention_mask``.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the position ids by one step (in place on ``model_kwargs``)."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs

    @classmethod
    def from_vision_text_pretrained(
        cls,
        vision_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        text_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs,
    ) -> FlaxViTGPT2LMPreTrainedModel:
        """Assemble a model from a pretrained ViT encoder and a pretrained GPT-2 LM-head decoder.

        Keyword arguments prefixed with ``vision_`` / ``text_`` are stripped of their prefix and routed to the
        corresponding ``from_pretrained`` call. Within those groups:

        * ``model`` supplies an already-loaded model instead of loading one;
        * ``model_args`` supplies positional arguments;
        * ``config`` overrides the auto-loaded configuration.

        The decoder configuration always gets ``add_cross_attention = True``. ``project_encoder`` and ``dtype``
        are popped from the remaining kwargs. Everything left is passed to
        ``ViTGPT2Config.from_vision_text_configs``.

        Returns:
            A freshly initialized ``cls`` whose ``params["model"]["encoder"]`` / ``["decoder"]`` are replaced by
            the parameters of the loaded vision / text models.
        """

        vision_kwargs = {
            kwarg[len("vision_"):]: value
            for kwarg, value in kwargs.items()
            if kwarg.startswith("vision_")
        }

        text_kwargs = {
            kwarg[len("text_"):]: value
            for kwarg, value in kwargs.items()
            if kwarg.startswith("text_")
        }

        # remove vit & gpt2 kwargs from kwargs
        for key in vision_kwargs.keys():
            del kwargs["vision_" + key]
        for key in text_kwargs.keys():
            del kwargs["text_" + key]

        vision_model_args = vision_kwargs.pop('model_args', [])
        text_model_args = text_kwargs.pop('model_args', [])

        # Load and initialize the vit & gpt2 model
        vision_model = vision_kwargs.pop("model", None)
        text_model = text_kwargs.pop("model", None)

        if vision_model is None:
            assert (
                vision_pretrained_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `vision_pretrained_model_name_or_path` has to be defined"

            if "config" not in vision_kwargs:
                vision_config = ViTConfig.from_pretrained(vision_pretrained_model_name_or_path)
                vision_kwargs["config"] = vision_config

            # TODO: How to deal with model_args?
            vision_model = FlaxViTModel.from_pretrained(
                vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
            )

        project_encoder = kwargs.pop("project_encoder", None)
        if text_model is None:
            assert (
                text_pretrained_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `text_pretrained_model_name_or_path` has to be defined"

            if "config" not in text_kwargs:
                text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
                text_config.project_encoder = text_kwargs.pop("project_encoder", None)
                if project_encoder is not None:
                    text_config.project_encoder = project_encoder
                text_kwargs["config"] = text_config

            text_kwargs["config"].add_cross_attention = True

            # TODO: How to deal with model_args?
            text_model = FlaxGPT2LMHeadModel.from_pretrained(
                text_pretrained_model_name_or_path, *text_model_args, **text_kwargs
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = ViTGPT2Config.from_vision_text_configs(
            vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
        )

        # The leftover kwargs are configuration options. ``FlaxViTGPT2LMPreTrainedModel.__init__`` forwards its
        # own **kwargs to the Flax module constructor, which accepts only ``config``/``dtype``, so passing config
        # options there raises ``TypeError``. Forward only the arguments ``__init__`` consumes itself.
        model_init_kwargs = {key: kwargs[key] for key in ("input_shape", "seed") if key in kwargs}

        # init model
        model = cls(config, *model_args, dtype=dtype, **model_init_kwargs)
        model.params["model"]["encoder"] = vision_model.params
        model.params["model"]["decoder"] = text_model.params

        return model