
idefics2-8b-init?

#5
by giobin - opened

My compliments on the great work you have been doing with idefics2 (and IDEFICS before it)! Would it be possible to release the idefics2 checkpoint from before the pretraining phase (that is, before idefics2-8b-base)? That would help people trying to "reproduce" at least part of the training. Basically, I am asking for the initialization code or weights of the idefics2 modality projection layers. That would be great!
Thanks!

giobin changed discussion status to closed
giobin changed discussion status to open
HuggingFaceM4 org

Hi @giobin, here is our code for initializing the modules:

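def _init_weights(self, module):
        # Helper: draw a Linear layer's weight from N(mean, std^2) and zero its bias.
        # Judging by its name, the context manager gathers DeepSpeed-sharded parameters so they can be modified in place.
        def init_a_linear(module, mean=0.0, std=self.config.initializer_range):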
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
                module.weight.data.normal_(mean=mean, std=std)
                if module.bias is not None:
                    with ContextManagers(deepspeed_gathered_parameters_context_manager(module.bias, modify=True)):
                        module.bias.data.zero_()

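        # MLP blocks: each Linear uses std = sqrt(0.4 / (hidden_size * factor)), with factor 2 for down_proj and 1 otherwise.
        if isinstance(module, MLP):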
            for sub_module_name, sub_module in module.named_modules():
                if isinstance(sub_module, nn.Linear):
                    factor = 1.0
                    if "down_proj" in sub_module_name:
                        factor = 2.0
                    init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

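        # Perceiver resampler: latents use std = sqrt(1 / hidden_size); o_proj layers are scaled down with the resampler depth.
        if isinstance(module, PerceiverResampler):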
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.latents, modify=True)):
                module.latents.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
            for sub_module_name, sub_module in module.named_modules():
                if isinstance(sub_module, nn.Linear):
                    factor = 1.0
                    if "o_proj" in sub_module_name:
                        factor = 2.0 * self.config.perceiver_config.resampler_depth
                    init_a_linear(sub_module, std=(0.4 / (self.config.hidden_size * factor)) ** 0.5)

        elif isinstance(module, nn.Embedding):
            # Embeddings: std = sqrt(1 / hidden_size), with the padding row (if any) zeroed.
            with ContextManagers(deepspeed_gathered_parameters_context_manager(module.weight, modify=True)):
                module.weight.data.normal_(mean=0.0, std=(1.0 / self.config.hidden_size) ** 0.5)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
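        elif isinstance(module, DecoupledLinear):
            # Only the extra projection (additional_fc), if present, is initialized here, scaled by its own fan-in.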
            if hasattr(module, "additional_fc"):
                init_a_linear(module.additional_fc, std=(1.0 / (module.additional_fc.in_features)) ** 0.5)
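
To make the scaling rules in the snippet above easier to check, here is a minimal pure-Python sketch that only computes the standard deviations it would use. The function names are hypothetical and not part of the idefics2 code base. The values `hidden_size = 4096` and `resampler_depth = 3` are illustrative assumptions; which config object `self.config.hidden_size` refers to is not stated in the snippet.

```python
import math


def mlp_linear_std(hidden_size: int, sub_module_name: str) -> float:
    """Std for a Linear inside an MLP block: sqrt(0.4 / (hidden_size * factor))."""
    # down_proj uses factor 2, every other Linear uses factor 1.
    factor = 2.0 if "down_proj" in sub_module_name else 1.0
    return math.sqrt(0.4 / (hidden_size * factor))


def resampler_linear_std(hidden_size: int, sub_module_name: str, resampler_depth: int) -> float:
    """Std for a Linear inside the perceiver resampler."""
    # o_proj is scaled down by twice the resampler depth.
    factor = 2.0 * resampler_depth if "o_proj" in sub_module_name else 1.0
    return math.sqrt(0.4 / (hidden_size * factor))


def embedding_std(hidden_size: int) -> float:
    """Std for embeddings and for the resampler latents: sqrt(1 / hidden_size)."""
    return math.sqrt(1.0 / hidden_size)


def additional_fc_std(in_features: int) -> float:
    """Std for DecoupledLinear.additional_fc: sqrt(1 / in_features)."""
    return math.sqrt(1.0 / in_features)


# Illustrative values only (assumptions, not read from a real config).
hidden_size = 4096
resampler_depth = 3

for name in ("gate_proj", "up_proj", "down_proj"):
    print(f"MLP {name}: std = {mlp_linear_std(hidden_size, name):.6f}")
for name in ("q_proj", "o_proj"):
    std = resampler_linear_std(hidden_size, name, resampler_depth)
    print(f"resampler {name}: std = {std:.6f}")
print(f"embedding / latents: std = {embedding_std(hidden_size):.6f}")
```

The sketch leaves out the DeepSpeed parameter gathering and the actual tensor writes; it is only meant to show how the `factor` term shrinks the spread of the output projections relative to the other linear layers.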
