import jax
import flax.linen as nn

from transformers.models.bart.modeling_flax_bart import (
    FlaxBartModule,
    FlaxBartForConditionalGenerationModule,
    FlaxBartForConditionalGeneration,
    FlaxBartEncoder,
    FlaxBartDecoder,
)

from transformers import BartConfig


class CustomFlaxBartModule(FlaxBartModule):
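    # mirrors FlaxBartModule, but the decoder gets its own embedding table and
    # position range sized for the image token vocabulary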
    def setup(self):
        # we keep shared to easily load pre-trained weights
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # a separate embedding is used for the decoder
        self.decoder_embed = nn.Embed(
            self.config.image_vocab_size + 1,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.encoder = FlaxBartEncoder(
            self.config, dtype=self.dtype, embed_tokens=self.shared
        )

        # the decoder has a different config
        # TODO: should not be needed once we have custom config/module
        decoder_config = BartConfig(**self.config.to_dict())
        decoder_config.max_position_embeddings = (
            self.config.image_length + 1  # image tokens + BOS
        )
        decoder_config.vocab_size = self.config.image_vocab_size + 1
        self.decoder = FlaxBartDecoder(
            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
        )


class CustomFlaxBartForConditionalGenerationModule(
    FlaxBartForConditionalGenerationModule
):
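    # same as the parent module, but the seq2seq backbone is the custom module
    # above and the LM head / final logits bias match the image vocabulary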
    def setup(self):
        self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.image_vocab_size + 1,  # encoded image token space + 1 for bos
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
        )


class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
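    # swapping `module_class` is all that is needed: the pre-trained model
    # wrapper builds its module from this attribute, so weight loading and
    # generation utilities are inherited from FlaxBartForConditionalGeneration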
    module_class = CustomFlaxBartForConditionalGenerationModule
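

# A minimal usage sketch (not part of the original file): `image_vocab_size`
# and `image_length` are custom config fields this module expects; the values
# below are assumptions (e.g. a VQGAN codebook of 16384 codes and 256 image
# tokens per sample), and the tiny BART dimensions just keep the init cheap.
if __name__ == "__main__":
    config = BartConfig(
        d_model=256,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=4,
        decoder_attention_heads=4,
        encoder_ffn_dim=512,
        decoder_ffn_dim=512,
    )
    config.image_vocab_size = 16384  # assumed VQGAN codebook size
    config.image_length = 256  # assumed number of image tokens per image

    # instantiating the wrapper builds CustomFlaxBartModule with random weights;
    # pre-trained BART weights would normally be loaded with `from_pretrained`
    model = CustomFlaxBartForConditionalGeneration(config)
    # the decoder embedding now covers the image vocabulary plus BOS
    print(model.params["model"]["decoder_embed"]["embedding"].shape)  # (16385, 256)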