File size: 36,951 Bytes
2bbf92c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxSequenceClassifierOutput,
)
from transformers.models.bert.modeling_flax_bert import (
    FlaxBertEncoder,
    FlaxBertOnlyMLMHead,
    FlaxBertPooler,
    FlaxPreTrainedModel,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule

from .configuration_clip_vision_bert import CLIPVisionBertConfig


class FlaxCLIPVisionBertEmbeddings(nn.Module):
    """Build the joint text + image embedding sequence fed to the encoder.

    Text tokens are embedded as the sum of word, token-type and position
    embeddings. Images go through the CLIP vision module, are projected to
    the BERT hidden size by a dense layer, and are summed with their own
    token-type and position embeddings. The text and visual sequences are
    concatenated along axis 1, layer-normalised, and passed through dropout.
    """

    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        """Create the embedding tables, the CLIP vision module and the norm layers."""
        text_cfg = self.config.bert_config
        vision_cfg = self.config.clip_vision_config
        hidden_size = text_cfg.hidden_size
        init_range = text_cfg.initializer_range

        def make_table(num_rows):
            # Every lookup table shares the hidden size, initializer and dtype;
            # only the number of rows differs.
            return nn.Embed(
                num_embeddings=num_rows,
                features=hidden_size,
                embedding_init=jax.nn.initializers.normal(stddev=init_range),
                dtype=self.dtype,
            )

        # Textual lookup tables.
        self.word_embeddings = make_table(text_cfg.vocab_size)
        self.position_embeddings = make_table(text_cfg.max_position_embeddings)
        self.token_type_embeddings = make_table(text_cfg.type_vocab_size)

        # Image encoder followed by a projection to the BERT hidden size.
        self.clip_vision_module = FlaxCLIPVisionModule(vision_cfg, dtype=self.dtype)
        self.visual_projection = nn.Dense(
            features=hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(init_range, self.dtype),
        )

        # Visual lookup tables (sized like their textual counterparts).
        self.visual_position_embeddings = make_table(text_cfg.max_position_embeddings)
        self.visual_token_type_embeddings = make_table(text_cfg.type_vocab_size)

        self.LayerNorm = nn.LayerNorm(epsilon=text_cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=text_cfg.hidden_dropout_prob)

    def __call__(
        self,
        input_ids,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
    ):
        """Embed text and image inputs and return the combined hidden states.

        Args:
            input_ids: Text token ids (cast to int32 before lookup).
            token_type_ids: Text token-type ids (cast to int32).
            position_ids: Text position ids (cast to int32).
            pixel_values: Image input handed to the CLIP vision module.
            visual_token_type_ids: Token-type ids for the visual tokens.
            visual_position_ids: Position ids for the visual tokens.
            deterministic: When ``True`` dropout is disabled.

        Returns:
            The text and visual embeddings concatenated along axis 1, after
            layer normalisation and dropout.
        """
        # --- text branch ---
        token_vectors = self.word_embeddings(input_ids.astype("i4"))
        position_vectors = self.position_embeddings(position_ids.astype("i4"))
        segment_vectors = self.token_type_embeddings(token_type_ids.astype("i4"))
        text_sequence = token_vectors + segment_vectors + position_vectors

        # --- visual branch ---
        # Only the first output of the CLIP vision module is used
        # (presumably the per-patch hidden states — confirm against CLIP docs).
        clip_outputs = self.clip_vision_module(pixel_values=pixel_values)
        patch_vectors = self.visual_projection(clip_outputs[0])
        visual_segment_vectors = self.visual_token_type_embeddings(
            visual_token_type_ids.astype("i4")
        )
        visual_position_vectors = self.visual_position_embeddings(
            visual_position_ids.astype("i4")
        )
        visual_sequence = patch_vectors + visual_segment_vectors + visual_position_vectors

        # --- fuse: text tokens first, visual tokens after ---
        joint_sequence = jnp.concatenate((text_sequence, visual_sequence), axis=1)
        normalized = self.LayerNorm(joint_sequence)
        return self.dropout(normalized, deterministic=deterministic)


class FlaxCLIPVisionBertModule(nn.Module):
    """Joint text/image embeddings followed by a BERT encoder and optional pooler."""

    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        """Instantiate the embedding layer, the BERT encoder and the pooler."""
        text_cfg = self.config.bert_config
        self.embeddings = FlaxCLIPVisionBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(text_cfg, dtype=self.dtype)
        self.pooler = FlaxBertPooler(text_cfg, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode a text + image pair.

        The text and visual attention masks are concatenated along axis 1 so
        that they line up with the concatenated embedding sequence.

        Returns:
            A ``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is truthy;
            otherwise a tuple starting with the last hidden state, followed by
            the pooled output (only when pooling is enabled) and the remaining
            encoder outputs.
        """
        joint_embeddings = self.embeddings(
            input_ids,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_token_type_ids,
            visual_position_ids,
            deterministic=deterministic,
        )
        joint_mask = jnp.concatenate((attention_mask, visual_attention_mask), axis=1)

        encoder_outputs = self.encoder(
            joint_embeddings,
            joint_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooled = None
        if self.add_pooling_layer:
            pooled = self.pooler(sequence_output)

        if return_dict:
            return FlaxBaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple output: leave the pooled slot out entirely when there is no pooler.
        extras = encoder_outputs[1:]
        if pooled is None:
            return (sequence_output,) + extras
        return (sequence_output, pooled) + extras


class FlaxCLIPVisionBertModel(FlaxPreTrainedModel):
    """Pretrained-model wrapper around ``FlaxCLIPVisionBertModule``.

    Builds default inputs, initialises parameters, and can assemble a combined
    model from separately pretrained BERT and CLIP vision checkpoints.
    """

    config_class = CLIPVisionBertConfig
    module_class = FlaxCLIPVisionBertModule

    def __init__(
        self,
        config: CLIPVisionBertConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Create the module and pass it to the ``FlaxPreTrainedModel`` constructor.

        Args:
            config: Combined config exposing ``bert_config`` and
                ``clip_vision_config``.
            input_shape: Triple ``(text_shape, pixel_shape, visual_shape)``
                consumed by ``init_weights``. When ``None`` it defaults to a
                ``(1, 1)`` text input, a ``(1, image_size, image_size, 3)``
                image and ``(1, (image_size // patch_size) ** 2 + 1)`` visual
                tokens.
            seed: Seed forwarded to the base class.
            dtype: Computation dtype handed to the module.
            **kwargs: Extra fields forwarded to ``module_class``.
        """
        if input_shape is None:
            image_size = config.clip_vision_config.image_size
            patch_size = config.clip_vision_config.patch_size
            input_shape = (
                (1, 1),
                (1, image_size, image_size, 3),
                (1, (image_size // patch_size) ** 2 + 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module parameters by running it on dummy inputs.

        Args:
            rng: PRNG key; used for the random dummy image and split into the
                ``params`` and ``dropout`` streams.
            input_shape: ``(text_shape, pixel_shape, visual_shape)`` triple.

        Returns:
            The ``"params"`` collection produced by ``module.init``.
        """
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape
        )
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])
        visual_attention_mask = jnp.ones(input_shape[2])
        visual_token_type_ids = jnp.ones(input_shape[2])
        # Dummy visual positions are all zero here.
        visual_position_ids = jnp.broadcast_to(
            jnp.zeros(jnp.atleast_2d(visual_token_type_ids).shape[-1]), input_shape[2]
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the model, filling in any optional inputs that were not given.

        Defaults: ``token_type_ids`` are zeros, ``position_ids`` count from 0
        along the last axis, ``attention_mask`` / ``visual_attention_mask`` /
        ``visual_token_type_ids`` are ones. Output flags fall back to the
        values stored on ``config.bert_config``.

        Args:
            input_ids: Text token ids.
            pixel_values: Image batch; its first dimension sets the batch size
                of the default visual inputs, so it must be provided.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.
            train: When ``True`` the module runs non-deterministically.

        Returns:
            Whatever the underlying module returns for the given
            ``return_dict`` setting.
        """
        bert_config = self.config.bert_config
        if output_attentions is None:
            output_attentions = bert_config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = bert_config.output_hidden_states
        if return_dict is None:
            return_dict = bert_config.return_dict

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) # Don't need this for torch permuted input

        # (batch, number of visual tokens): one token per patch plus one extra.
        vision_config = self.config.clip_vision_config
        visual_sequence_length = (
            pixel_values.shape[0],
            (vision_config.image_size // vision_config.patch_size) ** 2 + 1,
        )

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_sequence_length)

        if visual_position_ids is None:
            # NOTE(review): this broadcasts a scalar (the size of the last axis
            # of visual_token_type_ids), so every visual token receives the
            # same position id, whereas init_weights uses zeros. Kept as-is
            # because existing checkpoints may rely on it — confirm intent
            # before changing.
            visual_position_ids = jnp.broadcast_to(
                jnp.atleast_2d(visual_token_type_ids).shape[-1], visual_sequence_length
            )

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_sequence_length)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    @classmethod
    def from_bert_clip_vision_pretrained(
        cls,
        bert_model_name_or_path: str = None,
        clip_vision_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        """Assemble a combined model from pretrained BERT and CLIP vision models.

        Keyword arguments starting with ``bert_`` are forwarded (prefix
        stripped) to the BERT loader, and those starting with ``clip_vision_``
        to the CLIP vision loader. ``bert_model`` / ``clip_vision_model`` may
        carry already-loaded models, in which case the corresponding
        checkpoint is not loaded. ``dtype`` is popped and used for the new
        model; all remaining kwargs go to both the config builder and the
        model constructor.

        Args:
            bert_model_name_or_path: BERT checkpoint, loaded with
                ``from_pt=True`` when no ``bert_model`` is supplied.
            clip_vision_model_name_or_path: CLIP vision checkpoint, loaded
                when no ``clip_vision_model`` is supplied.
            *model_args: Positional arguments forwarded to both loaders and to
                the model constructor.

        Returns:
            A new model whose parameters are overwritten with the BERT and
            CLIP vision parameters.

        Raises:
            AssertionError: If a sub-model is neither passed in nor named.
        """
        bert_prefix = "bert_"
        clip_vision_prefix = "clip_vision_"

        # Select, strip and (below) delete with the SAME prefix. Previously
        # the selection used "text_" / "vision_" while stripping and deletion
        # used "bert_" / "clip_vision_", which raised KeyError for any matching
        # kwarg and never routed the documented prefixes.
        kwargs_bert = {
            argument[len(bert_prefix) :]: value
            for argument, value in kwargs.items()
            if argument.startswith(bert_prefix)
        }

        kwargs_clip_vision = {
            argument[len(clip_vision_prefix) :]: value
            for argument, value in kwargs.items()
            if argument.startswith(clip_vision_prefix)
        }

        # remove the routed sub-model kwargs from kwargs
        for key in kwargs_bert:
            del kwargs[bert_prefix + key]
        for key in kwargs_clip_vision:
            del kwargs[clip_vision_prefix + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            assert (
                bert_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
            from transformers import FlaxBertModel

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertModel.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            assert (
                clip_vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
            from transformers import FlaxCLIPVisionModel

            if "config" not in kwargs_clip_vision:
                from transformers import CLIPVisionConfig

                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionBertConfig.from_bert_clip_vision_configs(
            bert_model.config, clip_vision_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        # Overwrite the freshly initialised parameters with pretrained ones.
        for key in model.params.keys():
            if key != "embeddings":
                model.params[key] = bert_model.params[key]
            else:
                # The embeddings hold both the CLIP encoder and the BERT
                # lookup tables; visual-only entries keep their fresh init.
                model.params["embeddings"][
                    "clip_vision_module"
                ] = clip_vision_model.params
                for sub_key in bert_model.params[key]:
                    model.params[key][sub_key] = bert_model.params[key][sub_key]

        return model


# flax_model = FlaxCLIPVisionBertModel.from_bert_clip_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32', seed=0, dtype=jnp.float32)
# outputs = flax_model(input_ids, attention_mask,token_type_ids, position_ids, pixel_values, visual_attention_mask, visual_token_type_ids, visual_position_ids, output_hidden_states=True)


class FlaxCLIPVisionBertForMaskedLMModule(nn.Module):
    """CLIP-vision + BERT backbone (without pooler) topped with a BERT MLM head."""

    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the pooler-free backbone and the masked-LM head."""
        self.model = FlaxCLIPVisionBertModule(
            config=self.config, dtype=self.dtype, add_pooling_layer=False
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config.bert_config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Compute masked-LM prediction scores over the joint sequence.

        Returns:
            A ``FlaxMaskedLMOutput`` when ``return_dict`` is truthy, otherwise
            a tuple of the logits followed by the remaining backbone outputs.
        """
        backbone_outputs = self.model(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = backbone_outputs[0]

        # With tied embeddings the head reuses the word-embedding matrix.
        tied_embedding = None
        if self.config.bert_config.tie_word_embeddings:
            embedding_params = self.model.variables["params"]["embeddings"]
            tied_embedding = embedding_params["word_embeddings"]["embedding"]

        prediction_scores = self.cls(sequence_output, shared_embedding=tied_embedding)

        if return_dict:
            return FlaxMaskedLMOutput(
                logits=prediction_scores,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )
        return (prediction_scores,) + backbone_outputs[1:]


class FlaxCLIPVisionBertForMaskedLM(FlaxPreTrainedModel):
    """Pretrained-model wrapper around ``FlaxCLIPVisionBertForMaskedLMModule``.

    Builds default inputs, initialises parameters, and can assemble a
    masked-LM model from separately pretrained BERT (masked-LM) and CLIP
    vision checkpoints.
    """

    config_class = CLIPVisionBertConfig
    module_class = FlaxCLIPVisionBertForMaskedLMModule

    def __init__(
        self,
        config: CLIPVisionBertConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Create the module and pass it to the ``FlaxPreTrainedModel`` constructor.

        Args:
            config: Combined config exposing ``bert_config`` and
                ``clip_vision_config``.
            input_shape: Triple ``(text_shape, pixel_shape, visual_shape)``
                consumed by ``init_weights``. When ``None`` it defaults to a
                ``(1, 1)`` text input, a ``(1, image_size, image_size, 3)``
                image and ``(1, (image_size // patch_size) ** 2 + 1)`` visual
                tokens.
            seed: Seed forwarded to the base class.
            dtype: Computation dtype handed to the module.
            **kwargs: Extra fields forwarded to ``module_class``.
        """
        if input_shape is None:
            image_size = config.clip_vision_config.image_size
            patch_size = config.clip_vision_config.patch_size
            input_shape = (
                (1, 1),
                (1, image_size, image_size, 3),
                (1, (image_size // patch_size) ** 2 + 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module parameters by running it on dummy inputs.

        Args:
            rng: PRNG key; used for the random dummy image and split into the
                ``params`` and ``dropout`` streams.
            input_shape: ``(text_shape, pixel_shape, visual_shape)`` triple.

        Returns:
            The ``"params"`` collection produced by ``module.init``.
        """
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape
        )
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])
        visual_attention_mask = jnp.ones(input_shape[2])
        visual_token_type_ids = jnp.ones(input_shape[2])
        # Dummy visual positions are all zero here.
        visual_position_ids = jnp.broadcast_to(
            jnp.zeros(jnp.atleast_2d(visual_token_type_ids).shape[-1]), input_shape[2]
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the masked-LM model, filling in optional inputs that were not given.

        Defaults: ``token_type_ids`` are zeros, ``position_ids`` count from 0
        along the last axis, ``attention_mask`` / ``visual_attention_mask`` /
        ``visual_token_type_ids`` are ones. Output flags fall back to the
        values stored on ``config.bert_config``.

        Args:
            input_ids: Text token ids.
            pixel_values: Image batch; its first dimension sets the batch size
                of the default visual inputs, so it must be provided.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.
            train: When ``True`` the module runs non-deterministically.

        Returns:
            Whatever the underlying module returns for the given
            ``return_dict`` setting.
        """
        bert_config = self.config.bert_config
        if output_attentions is None:
            output_attentions = bert_config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = bert_config.output_hidden_states
        if return_dict is None:
            return_dict = bert_config.return_dict

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # (batch, number of visual tokens): one token per patch plus one extra.
        vision_config = self.config.clip_vision_config
        visual_sequence_length = (
            pixel_values.shape[0],
            (vision_config.image_size // vision_config.patch_size) ** 2 + 1,
        )

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_sequence_length)

        if visual_position_ids is None:
            # NOTE(review): this broadcasts a scalar (the number of visual
            # tokens), so every visual token receives the same position id,
            # whereas init_weights uses zeros. Kept as-is because existing
            # checkpoints may rely on it — confirm intent before changing.
            # The count is read straight from the shape tuple instead of
            # allocating a throwaway array just to inspect its shape.
            visual_position_ids = jnp.broadcast_to(
                visual_sequence_length[-1], visual_sequence_length
            )

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_sequence_length)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Delegate to the base-class loader unchanged."""
        # At the moment fast initialization is not supported
        # for composite models
        # kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_clip_vision_bert_pretrained(
        cls,
        clip_vision_model_name_or_path: str = None,
        bert_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        """Assemble a masked-LM model from pretrained CLIP vision and BERT models.

        Keyword arguments starting with ``bert_`` are forwarded (prefix
        stripped) to the BERT loader, and those starting with ``clip_vision_``
        to the CLIP vision loader. ``bert_model`` / ``clip_vision_model`` may
        carry already-loaded models, in which case the corresponding
        checkpoint is not loaded. ``dtype`` is popped and used for the new
        model; all remaining kwargs go to both the config builder and the
        model constructor.

        Args:
            clip_vision_model_name_or_path: CLIP vision checkpoint, loaded
                when no ``clip_vision_model`` is supplied.
            bert_model_name_or_path: BERT masked-LM checkpoint, loaded with
                ``from_pt=True`` when no ``bert_model`` is supplied.
            *model_args: Positional arguments forwarded to both loaders and to
                the model constructor.

        Returns:
            A new model whose ``cls`` head and backbone parameters are
            overwritten with the BERT and CLIP vision parameters.

        Raises:
            AssertionError: If a sub-model is neither passed in nor named.
        """
        bert_prefix = "bert_"
        clip_vision_prefix = "clip_vision_"

        # Select, strip and (below) delete with the SAME prefix. Previously
        # the selection used "text_" / "vision_" while stripping and deletion
        # used "bert_" / "clip_vision_", which raised KeyError for any matching
        # kwarg and never routed the documented prefixes.
        kwargs_bert = {
            argument[len(bert_prefix) :]: value
            for argument, value in kwargs.items()
            if argument.startswith(bert_prefix)
        }

        kwargs_clip_vision = {
            argument[len(clip_vision_prefix) :]: value
            for argument, value in kwargs.items()
            if argument.startswith(clip_vision_prefix)
        }

        # remove the routed sub-model kwargs from kwargs
        for key in kwargs_bert:
            del kwargs[bert_prefix + key]
        for key in kwargs_clip_vision:
            del kwargs[clip_vision_prefix + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            assert (
                bert_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
            from transformers import FlaxBertForMaskedLM

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertForMaskedLM.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            assert (
                clip_vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
            from transformers import FlaxCLIPVisionModel

            if "config" not in kwargs_clip_vision:
                from transformers import CLIPVisionConfig

                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionBertConfig.from_clip_vision_bert_configs(
            clip_vision_model.config, bert_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        # Overwrite the freshly initialised parameters with pretrained ones:
        # the MLM head comes from BERT's "cls", the backbone from BERT's "bert".
        model.params["cls"] = bert_model.params["cls"]
        bert_backbone = bert_model.params["bert"]
        for key in model.params["model"].keys():
            if key != "embeddings":
                model.params["model"][key] = bert_backbone[key]
            else:
                # The embeddings hold both the CLIP encoder and the BERT
                # lookup tables; visual-only entries keep their fresh init.
                model.params["model"]["embeddings"][
                    "clip_vision_module"
                ] = clip_vision_model.params
                for sub_key in bert_backbone[key]:
                    model.params["model"][key][sub_key] = bert_backbone[key][sub_key]

        return model


class FlaxCLIPVisionBertForSequenceClassificationModule(nn.Module):
    """Flax module adding a linear classification head to `FlaxCLIPVisionBertModule`.

    The element at index 1 of the backbone output (used here as the pooled
    representation) goes through dropout and a dense layer that produces
    `num_labels` logits.
    """

    # Composite configuration; `bert_config.hidden_dropout_prob` sets the dropout rate.
    config: CLIPVisionBertConfig
    # dtype handed to the backbone and to the dense classifier.
    dtype: jnp.dtype = jnp.float32
    # Output size of the classification head.
    num_labels: int = 3129  # TODO: Remove this hard-coding!

    def setup(self):
        """Create the backbone, the dropout layer and the classification head."""
        self.model = FlaxCLIPVisionBertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.bert_config.hidden_dropout_prob)
        self.classifier = nn.Dense(self.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Compute classification logits for a batch of text/image inputs.

        All textual and visual inputs are forwarded positionally, in the order
        received, to the backbone.

        Args:
            deterministic: When True, dropout is disabled.
            output_attentions: Forwarded to the backbone.
            output_hidden_states: Forwarded to the backbone.
            return_dict: When True, return a `FlaxSequenceClassifierOutput`;
                otherwise return a tuple.

        Returns:
            A `FlaxSequenceClassifierOutput` holding the logits plus the
            backbone's hidden states and attentions, or the tuple
            `(logits,) + backbone_outputs[2:]` when `return_dict` is False.
        """
        backbone_outputs = self.model(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Dropout on the pooled representation, then the linear head.
        pooled = self.dropout(backbone_outputs[1], deterministic=deterministic)
        logits = self.classifier(pooled)

        if return_dict:
            return FlaxSequenceClassifierOutput(
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        return (logits,) + backbone_outputs[2:]


class FlaxCLIPVisionBertForSequenceClassification(FlaxPreTrainedModel):
    config_class = CLIPVisionBertConfig
    module_class = FlaxCLIPVisionBertForSequenceClassificationModule

    def __init__(
        self,
        config: CLIPVisionBertConfig,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Build the Flax module and hand it to the parent constructor.

        Args:
            config: Composite configuration; its `clip_vision_config` supplies
                `image_size` and `patch_size` for the default input shape.
            input_shape: Triple of shapes `(text, image, visual_sequence)`.
                When None, a batch-of-one default is derived from `config`.
            seed: Passed through to the parent constructor.
            dtype: dtype given to both the module and the parent constructor.
            **kwargs: Extra keyword arguments forwarded to `module_class`.
        """
        if input_shape is None:
            vision_config = config.clip_vision_config
            image_size = vision_config.image_size
            # Number of patches per image plus one extra position
            # (presumably the CLIP class token -- confirm against the vision module).
            visual_length = (image_size // vision_config.patch_size) ** 2 + 1
            input_shape = (
                (1, 1),
                (1, image_size, image_size, 3),
                (1, visual_length),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module parameters from dummy inputs.

        Args:
            rng: PRNG key; used to sample the dummy image and then split into
                the "params" and "dropout" keys.
            input_shape: Indexable of three shapes: text inputs at index 0,
                pixel values at index 1, visual sequence inputs at index 2.

        Returns:
            The "params" collection produced by `self.module.init`.
        """
        # Dummy textual inputs.
        text_shape = input_shape[0]
        input_ids = jnp.zeros(text_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        token_type_ids = jnp.zeros_like(input_ids)
        text_positions = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
        position_ids = jnp.broadcast_to(text_positions, text_shape)

        # Dummy visual inputs: random image, all-ones mask and token types,
        # all-zero position ids.
        pixel_values = jax.random.normal(rng, input_shape[1])
        visual_shape = input_shape[2]
        visual_attention_mask = jnp.ones(visual_shape)
        visual_token_type_ids = jnp.ones(visual_shape)
        visual_positions = jnp.zeros(
            jnp.atleast_2d(visual_token_type_ids).shape[-1]
        )
        visual_position_ids = jnp.broadcast_to(visual_positions, visual_shape)

        params_rng, dropout_rng = jax.random.split(rng)
        variables = self.module.init(
            {"params": params_rng, "dropout": dropout_rng},
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )
        return variables["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Apply the classification module to a batch of text/image inputs.

        Args:
            input_ids: Text token ids; must expose `.shape`.
            attention_mask: Text attention mask. Defaults to ones like `input_ids`.
            token_type_ids: Text token types. Defaults to zeros like `input_ids`.
            position_ids: Text positions. Defaults to `0..seq_len-1` broadcast
                to `input_ids.shape`.
            pixel_values: Image batch; its first dimension is taken as the
                batch size. Required: the `None` default only exists to keep
                the keyword-friendly signature. The default `input_shape` of
                this class uses a channels-last layout
                (batch, image_size, image_size, 3) -- presumably the same is
                expected here; confirm against the vision module.
            visual_attention_mask: Defaults to ones of shape
                (batch, num_visual_tokens).
            visual_token_type_ids: Defaults to ones of shape
                (batch, num_visual_tokens).
            visual_position_ids: Defaults to an array of shape
                (batch, num_visual_tokens) filled with `num_visual_tokens`.
            params: Parameters to use instead of `self.params` (a falsy value
                falls back to `self.params`).
            dropout_rng: PRNG key registered under "dropout" when given.
            train: The module is called with `deterministic = not train`.
            output_attentions: Defaults to `config.bert_config.output_attentions`.
            output_hidden_states: Defaults to
                `config.bert_config.output_hidden_states`.
            return_dict: Defaults to `config.bert_config.return_dict`.

        Returns:
            Whatever `self.module.apply` returns for the given `return_dict`.

        Raises:
            ValueError: If `pixel_values` is None.
        """
        bert_config = self.config.bert_config
        if output_attentions is None:
            output_attentions = bert_config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = bert_config.output_hidden_states
        if return_dict is None:
            return_dict = bert_config.return_dict

        # The visual defaults below are sized from the image batch, so fail
        # early with a clear message instead of an AttributeError on `None`.
        if pixel_values is None:
            raise ValueError(
                "`pixel_values` has to be provided: it is required to build the visual inputs."
            )

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        vision_config = self.config.clip_vision_config
        # Patches per image plus one extra position.
        num_visual_tokens = (
            vision_config.image_size // vision_config.patch_size
        ) ** 2 + 1
        visual_sequence_length = (pixel_values.shape[0], num_visual_tokens)

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_sequence_length)

        if visual_position_ids is None:
            # NOTE(review): every default visual position id equals
            # `num_visual_tokens`, whereas `init_weights` uses zeros. The value
            # is deliberately kept as-is because existing checkpoints may have
            # been trained with it -- confirm before changing.
            visual_position_ids = jnp.broadcast_to(
                num_visual_tokens, visual_sequence_length
            )

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_sequence_length)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a pretrained model by delegating, unchanged, to the parent class."""
        # Original remark: at the moment fast initialization is not supported
        # for composite models.
        # NOTE(review): the line that would disable fast initialization is
        # commented out, so this override currently forwards all arguments
        # untouched.
        # kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_clip_vision_bert_pretrained(
        cls,
        clip_vision_model_name_or_path: str = None,
        bert_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        """Assemble a classifier from a CLIP vision model and a BERT model.

        Keyword arguments prefixed with `bert_` or `clip_vision_` are stripped
        of their prefix and routed to the respective sub-model loader; a
        ready-made sub-model can be supplied as `bert_model=` or
        `clip_vision_model=`. The remaining keyword arguments (except `dtype`)
        are passed both to `CLIPVisionBertConfig.from_clip_vision_bert_configs`
        and to the class constructor.

        Args:
            clip_vision_model_name_or_path: Identifier or path of the CLIP
                vision checkpoint; required unless `clip_vision_model` is given.
            bert_model_name_or_path: Identifier or path of the BERT checkpoint
                (loaded with `from_pt=True`); required unless `bert_model` is
                given.
            *model_args: Positional arguments forwarded to both sub-model
                loaders and to the class constructor.
            **kwargs: See above. `dtype` defaults to `jnp.float32`.

        Returns:
            A new model whose backbone weights are taken from the two
            sub-models. The classification head keeps its fresh initialisation.

        Raises:
            ValueError: If a sub-model is not supplied and its
                `*_name_or_path` argument is None.
        """
        kwargs_bert = {
            argument[len("bert_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("bert_")
        }

        kwargs_clip_vision = {
            argument[len("clip_vision_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("clip_vision_")
        }

        # remove text, vision kwargs from kwargs
        for key in kwargs_bert:
            del kwargs["bert_" + key]
        for key in kwargs_clip_vision:
            del kwargs["clip_vision_" + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            # Explicit raise instead of `assert`: asserts vanish under `python -O`.
            if bert_model_name_or_path is None:
                raise ValueError(
                    "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
                )
            from transformers import FlaxBertForSequenceClassification

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertForSequenceClassification.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            if clip_vision_model_name_or_path is None:
                raise ValueError(
                    "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
                )
            from transformers import FlaxCLIPVisionModel

            if "config" not in kwargs_clip_vision:
                from transformers import CLIPVisionConfig

                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionBertConfig.from_clip_vision_bert_configs(
            clip_vision_model.config, bert_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        # The BERT classifier weights are intentionally not copied, so the
        # head keeps its fresh initialisation (presumably because its output
        # size differs from the BERT checkpoint's -- confirm).
        bert_params = bert_model.params["bert"]
        backbone_params = model.params["model"]
        for key in backbone_params:
            if key != "embeddings":
                backbone_params[key] = bert_params[key]
                continue

            # The embeddings hold the vision tower next to the BERT embedding
            # tables: plug in the CLIP weights, then overwrite the BERT
            # entries one by one so the vision entry is left in place.
            embeddings = backbone_params["embeddings"]
            embeddings["clip_vision_module"] = clip_vision_model.params
            for sub_key in bert_params["embeddings"]:
                embeddings[sub_key] = bert_params["embeddings"][sub_key]

        return model