from transformers import PretrainedConfig
from torchscale.architecture.config import EncoderConfig


class ViVQAConfig(PretrainedConfig):
    """Hugging Face configuration that wraps a torchscale EncoderConfig for the ViVQA model."""

    model_type = "vivqa"

    def __init__(
        self,
        drop_path_rate: float = 0.0,
        mlp_ratio: float = 4.0,
        encoder_layers: int = 6,
        encoder_attention_heads: int = 6,
        multiway: bool = True,
        layernorm_embedding: bool = False,
        normalize_output: bool = True,
        no_output_layer: bool = True,
        encoder_embed_dim: int = 768,
        **kwargs
    ):
        # Build the torchscale encoder configuration from the constructor
        # arguments (the embedding dimension was previously hard-coded to 768,
        # which silently ignored the encoder_embed_dim argument).
        args = EncoderConfig(
            multiway=multiway,
            layernorm_embedding=layernorm_embedding,
            normalize_output=normalize_output,
            no_output_layer=no_output_layer,
            drop_path_rate=drop_path_rate,
            encoder_embed_dim=encoder_embed_dim,
            encoder_attention_heads=encoder_attention_heads,
            encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
            encoder_layers=encoder_layers,
        )
        # Copy every field of the EncoderConfig onto this config so that it is
        # serialized by save_pretrained / from_pretrained like any other attribute.
        for key, value in args.__dict__.items():
            setattr(self, key, value)
        super().__init__(**kwargs)
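

# --- Usage sketch (illustrative, not part of the original file) -------------
# Shows how the config can be instantiated and round-tripped through the
# standard Hugging Face serialization helpers. The directory name
# "vivqa_config" is hypothetical.
if __name__ == "__main__":
    config = ViVQAConfig(encoder_layers=12, encoder_attention_heads=12)
    print(config.encoder_ffn_embed_dim)       # 3072 = 768 * mlp_ratio
    config.save_pretrained("vivqa_config")    # writes vivqa_config/config.json
    reloaded = ViVQAConfig.from_pretrained("vivqa_config")
    assert reloaded.encoder_layers == 12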