File size: 1,130 Bytes
065e0ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import PretrainedConfig, AutoConfig


class CLIPEncoderDecoderConfig(PretrainedConfig):
    """Composite configuration holding an encoder config and a decoder config.

    The two sub-configurations are stored on ``self.encoder`` and
    ``self.decoder`` as ``PretrainedConfig`` objects, and
    ``is_encoder_decoder`` is always set to ``True``.

    Args:
        encoder: Encoder configuration, either a ``PretrainedConfig`` instance
            or a dict carrying a ``"model_type"`` key (the shape produced by
            ``PretrainedConfig.to_dict()``). When omitted, the configuration of
            ``facebook/convnextv2-base-22k-224`` is loaded and its
            ``hidden_size`` attribute is set to 1024.
        decoder: Decoder configuration, same accepted forms as ``encoder``.
            When omitted, the configuration of ``clicknext/phayathaibert`` is
            loaded.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If a sub-configuration dict has no ``"model_type"`` key.
        TypeError: If a sub-configuration is neither a dict nor a
            ``PretrainedConfig``.

    NOTE(review): the ``hidden_size = 1024`` patch is applied only to the
    default encoder; a caller-supplied encoder config is used as given.
    Presumably downstream code reads ``config.encoder.hidden_size`` - confirm
    that custom encoder configs expose it.
    """

    model_type = "clip-encoder-decoder"

    def __init__(
        self,
        encoder=None,
        decoder=None,
        **kwargs):
        super().__init__(**kwargs)

        if encoder is None:
            # Default path: same checkpoint and attribute patch as before.
            # This touches the hub (or the local cache) on construction.
            encoder = AutoConfig.from_pretrained('facebook/convnextv2-base-22k-224')
            encoder.hidden_size = 1024
        else:
            encoder = self._coerce_sub_config(encoder, "encoder")

        if decoder is None:
            decoder = AutoConfig.from_pretrained('clicknext/phayathaibert')
        else:
            decoder = self._coerce_sub_config(decoder, "decoder")

        self.encoder = encoder
        self.decoder = decoder
        self.is_encoder_decoder = True

    @staticmethod
    def _coerce_sub_config(value, role):
        """Return ``value`` as a ``PretrainedConfig``, rebuilding it from a dict if needed."""
        if isinstance(value, PretrainedConfig):
            return value
        if isinstance(value, dict):
            # Work on a copy so the caller's dict keeps its "model_type" entry.
            params = dict(value)
            sub_model_type = params.pop("model_type", None)
            if sub_model_type is None:
                raise ValueError(
                    f"The {role} configuration dict must contain a 'model_type' key."
                )
            return AutoConfig.for_model(sub_model_type, **params)
        raise TypeError(
            f"The {role} configuration must be a dict or a PretrainedConfig, "
            f"got {type(value).__name__}."
        )

    @classmethod
    def from_encoder_decoder_configs(
            cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a [`CLIPEncoderDecoderConfig`] (or a derived class) from an encoder model
        configuration and a decoder model configuration.

        `decoder_config` is modified in place: `is_decoder` and `add_cross_attention` are set to
        `True` before it is serialized.

        Args:
            encoder_config: Configuration of the encoder.
            decoder_config: Configuration of the decoder.
            **kwargs: Extra keyword arguments passed to the constructor.

        Returns:
            [`CLIPEncoderDecoderConfig`]: An instance of a configuration object
        """
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)