# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert SigLIP checkpoints from the original repository.

URL: https://github.com/google-research/big_vision/tree/main
"""


import argparse
import collections
import collections.abc
from pathlib import Path

import numpy as np
import requests
import torch
from huggingface_hub import hf_hub_download
from numpy import load
from PIL import Image

from configuration_siglip import SiglipConfig
from modeling_siglip import SiglipModel
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

# Checkpoint that is read when the caller does not pass an explicit path.
DEFAULT_CHECKPOINT_PATH = "./webli_en_so400m_384_58765454.npz"

# Substrings that identify an attention projection in a target parameter name.
_ATTENTION_PROJECTIONS = ("out_proj", "v_proj", "k_proj", "q_proj")

# (original suffix, HuggingFace suffix) pairs for one encoder block. The same
# table is used for both the vision and the text tower.
_ENCODER_LAYER_KEY_MAP = (
    ("LayerNorm_0/scale", "layer_norm1.weight"),
    ("LayerNorm_0/bias", "layer_norm1.bias"),
    ("LayerNorm_1/scale", "layer_norm2.weight"),
    ("LayerNorm_1/bias", "layer_norm2.bias"),
    ("MlpBlock_0/Dense_0/kernel", "mlp.fc1.weight"),
    ("MlpBlock_0/Dense_0/bias", "mlp.fc1.bias"),
    ("MlpBlock_0/Dense_1/kernel", "mlp.fc2.weight"),
    ("MlpBlock_0/Dense_1/bias", "mlp.fc2.bias"),
    ("MultiHeadDotProductAttention_0/key/kernel", "self_attn.k_proj.weight"),
    ("MultiHeadDotProductAttention_0/key/bias", "self_attn.k_proj.bias"),
    ("MultiHeadDotProductAttention_0/value/kernel", "self_attn.v_proj.weight"),
    ("MultiHeadDotProductAttention_0/value/bias", "self_attn.v_proj.bias"),
    ("MultiHeadDotProductAttention_0/query/kernel", "self_attn.q_proj.weight"),
    ("MultiHeadDotProductAttention_0/query/bias", "self_attn.q_proj.bias"),
    ("MultiHeadDotProductAttention_0/out/kernel", "self_attn.out_proj.weight"),
    ("MultiHeadDotProductAttention_0/out/bias", "self_attn.out_proj.bias"),
)


def get_siglip_config(model_name):
    """Build a `SiglipConfig` with size-dependent fields overridden.

    The variant is selected by substring match on `model_name`, checked in the
    order "base", "large", "so400m".

    Args:
        model_name: Name of the model to convert.

    Returns:
        The populated `SiglipConfig`.

    Raises:
        ValueError: If `model_name` matches none of the known variants.
    """
    config = SiglipConfig()

    # size of the architecture
    if "base" in model_name:
        config.vision_config.image_size = 224
        config.vision_config.patch_size = 16
        config.text_config.vocab_size = 32000
        config.text_config.hidden_size = 768
        config.text_config.intermediate_size = 3072
        config.text_config.max_position_embeddings = 64
        config.text_config.num_attention_heads = 12
    elif "large" in model_name:
        # NOTE(review): only vision fields are overridden for this variant;
        # everything else keeps the SiglipConfig defaults - confirm intended.
        config.vision_config.hidden_size = 1024
        config.vision_config.num_hidden_layers = 24
        config.vision_config.num_attention_heads = 16
    elif "so400m" in model_name:
        config.text_config.vocab_size = 32000
        config.text_config.hidden_size = 1152
        config.text_config.intermediate_size = 4304
        config.text_config.num_hidden_layers = 27
        config.text_config.num_attention_heads = 16
        config.text_config.max_position_embeddings = 64
        config.vision_config.hidden_size = 1152
        config.vision_config.intermediate_size = 4304
        config.vision_config.num_hidden_layers = 27
        config.vision_config.num_attention_heads = 16
        config.vision_config.image_size = 384
        config.vision_config.patch_size = 14
    else:
        raise ValueError("Model not supported")

    return config


def _encoder_layer_rename_keys(src_prefix, dest_prefix, num_layers):
    """Return the (old, new) key pairs for every encoder block of one tower."""
    return [
        (f"{src_prefix}/encoderblock_{i}/{src}", f"{dest_prefix}.encoder.layers.{i}.{dest}")
        for i in range(num_layers)
        for src, dest in _ENCODER_LAYER_KEY_MAP
    ]


def create_rename_keys(config):
    """List the (original key, HuggingFace key) pairs to rename.

    The query/key/value projections of the vision attention-pooling head are
    deliberately absent; `read_in_q_k_v_head` fuses them separately.

    Args:
        config: Configuration whose `vision_config.num_hidden_layers` and
            `text_config.num_hidden_layers` give the number of encoder blocks.

    Returns:
        A list of `(old_key, new_key)` tuples.
    """
    vision = "vision_model.vision_model"
    text = "text_model.text_model"
    map_head = "params/img/MAPHead_0"

    # vision encoder
    rename_keys = [
        ("params/img/embedding/kernel", f"{vision}.embeddings.patch_embedding.weight"),
        ("params/img/embedding/bias", f"{vision}.embeddings.patch_embedding.bias"),
        ("params/img/pos_embedding", f"{vision}.embeddings.position_embedding.weight"),
    ]
    rename_keys += _encoder_layer_rename_keys(
        "params/img/Transformer", vision, config.vision_config.num_hidden_layers
    )
    rename_keys += [
        ("params/img/Transformer/encoder_norm/scale", f"{vision}.post_layernorm.weight"),
        ("params/img/Transformer/encoder_norm/bias", f"{vision}.post_layernorm.bias"),
        (f"{map_head}/probe", f"{vision}.head.probe"),
        (f"{map_head}/LayerNorm_0/scale", f"{vision}.head.layernorm.weight"),
        (f"{map_head}/LayerNorm_0/bias", f"{vision}.head.layernorm.bias"),
        (f"{map_head}/MlpBlock_0/Dense_0/kernel", f"{vision}.head.mlp.fc1.weight"),
        (f"{map_head}/MlpBlock_0/Dense_0/bias", f"{vision}.head.mlp.fc1.bias"),
        (f"{map_head}/MlpBlock_0/Dense_1/kernel", f"{vision}.head.mlp.fc2.weight"),
        (f"{map_head}/MlpBlock_0/Dense_1/bias", f"{vision}.head.mlp.fc2.bias"),
        (f"{map_head}/MultiHeadDotProductAttention_0/out/kernel", f"{vision}.head.attention.out_proj.weight"),
        (f"{map_head}/MultiHeadDotProductAttention_0/out/bias", f"{vision}.head.attention.out_proj.bias"),
    ]

    # text encoder
    rename_keys += [
        ("params/txt/Embed_0/embedding", f"{text}.embeddings.token_embedding.weight"),
        ("params/txt/pos_embedding", f"{text}.embeddings.position_embedding.weight"),
    ]
    rename_keys += _encoder_layer_rename_keys(
        "params/txt/Encoder_0", text, config.text_config.num_hidden_layers
    )
    rename_keys += [
        ("params/txt/Encoder_0/encoder_norm/scale", f"{text}.final_layer_norm.weight"),
        ("params/txt/Encoder_0/encoder_norm/bias", f"{text}.final_layer_norm.bias"),
        ("params/txt/head/kernel", f"{text}.head.weight"),
        ("params/txt/head/bias", f"{text}.head.bias"),
    ]

    # learned temperature and bias
    rename_keys += [
        ("params/t", "temperature"),
        ("params/b", "bias"),
    ]

    return rename_keys


def rename_key(dct, old, new, config):
    """Move `dct[old]` to `dct[new]`, reshaping it and converting to a tensor.

    The reshapes/transposes applied depend only on substrings of `new`; the
    order of the steps below matters and mirrors the original conversion.

    Args:
        dct: Mutable state dict holding numpy arrays; modified in place.
        old: Key to remove.
        new: Key under which the converted tensor is stored.
        config: Configuration providing the vision/text hidden sizes.

    Raises:
        KeyError: If `old` is not present in `dct`.
    """
    val = dct.pop(old)

    # Attention projections are collapsed to 2D (biases become 1D further down).
    # NOTE(review): presumably the original stores them with separate
    # head / head_dim axes - confirm against the checkpoint.
    is_attention_proj = any(name in new for name in _ATTENTION_PROJECTIONS)
    if is_attention_proj and "vision" in new:
        val = val.reshape(-1, config.vision_config.hidden_size)
    if is_attention_proj and "text" in new:
        val = val.reshape(-1, config.text_config.hidden_size)

    if "patch_embedding.weight" in new:
        # Reorder the 4D kernel axes; presumably (H, W, in, out) -> (out, in, H, W).
        val = val.transpose(3, 2, 0, 1)
    elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
        val = val.T

    if "position_embedding" in new and "vision" in new:
        val = val.reshape(-1, config.vision_config.hidden_size)
    if "position_embedding" in new and "text" in new:
        val = val.reshape(-1, config.text_config.hidden_size)

    if new.endswith("bias"):
        val = val.reshape(-1)

    dct[new] = torch.from_numpy(val)


def read_in_q_k_v_head(state_dict, config):
    """Fuse the attention-pooling head's q/k/v projections into `in_proj_*`.

    Pops the separate query, key and value kernels and biases of
    `params/img/MAPHead_0` and stores them concatenated (in query, key, value
    order) as `in_proj_weight` and `in_proj_bias`.

    Args:
        state_dict: Mutable state dict; modified in place.
        config: Configuration providing `vision_config.hidden_size`.
    """
    src_prefix = "params/img/MAPHead_0/MultiHeadDotProductAttention_0"
    dest_prefix = "vision_model.vision_model.head.attention"
    hidden_size = config.vision_config.hidden_size

    weights = []
    biases = []
    # The concatenation order must stay query, key, value.
    for projection in ("query", "key", "value"):
        kernel = state_dict.pop(f"{src_prefix}/{projection}/kernel")
        weights.append(kernel.reshape(-1, hidden_size).T)
        biases.append(state_dict.pop(f"{src_prefix}/{projection}/bias").reshape(-1))

    # next, add them to the state dict as a single matrix + vector
    state_dict[f"{dest_prefix}.in_proj_weight"] = torch.from_numpy(np.concatenate(weights, axis=0))
    state_dict[f"{dest_prefix}.in_proj_bias"] = torch.from_numpy(np.concatenate(biases, axis=0))


# We will verify our results on an image of cute cats
def prepare_img():
    """Download the COCO validation image used for sanity checks.

    Returns:
        The image opened with PIL.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps the script from hanging on an unresponsive server, and
    # the status check avoids handing an error page to the image decoder.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw)
    return image


def flatten_nested_dict(params, parent_key="", sep="/"):
    """Flatten nested mappings into one dict with `sep`-joined keys.

    Args:
        params: Mapping whose values may be nested mutable mappings.
        parent_key: Prefix prepended to every key (used by the recursion).
        sep: Separator placed between nested key parts.

    Returns:
        A flat `dict` mapping joined keys to leaf values.
    """
    flat = {}
    for key, value in params.items():
        full_key = parent_key + sep + key if parent_key else key
        if isinstance(value, collections.abc.MutableMapping):
            flat.update(flatten_nested_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat


@torch.no_grad()
def convert_siglip_checkpoint(
    model_name, pytorch_dump_folder_path, push_to_hub=False, checkpoint_path=DEFAULT_CHECKPOINT_PATH
):
    """
    Copy/paste/tweak model's weights to our SigLIP structure.

    Args:
        model_name: Name of the model to convert; selects the configuration.
        pytorch_dump_folder_path: Output directory for the converted model, or
            `None` to skip saving.
        push_to_hub: Whether to push the converted model to the hub.
        checkpoint_path: Path to the original `.npz` checkpoint.
    """
    # define default SigLIP configuration
    config = get_siglip_config(model_name)

    # load original state dict; flattening reads every array, so the archive
    # can be closed right away instead of leaking the file handle
    with load(checkpoint_path) as data:
        state_dict = flatten_nested_dict(data)
    # keep a handle on the temperature before the key is renamed below
    original_temperature = state_dict["params/t"]

    # remove and rename some keys
    for src, dest in create_rename_keys(config):
        rename_key(state_dict, src, dest, config)

    # qkv matrices of attention pooling head need special treatment
    read_in_q_k_v_head(state_dict, config)

    # load HuggingFace model
    model = SiglipModel(config).eval()
    model.load_state_dict(state_dict)

    print("Original temperature:", original_temperature)

    # TODO create image processor
    # url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    # preprocess image
    # pixel_values = processor(image, return_tensors="pt").pixel_values

    if pytorch_dump_folder_path is not None:
        # parents=True so that a nested output directory can be created
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        # print(f"Saving processor to {pytorch_dump_folder_path}")
        # processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model.push_to_hub(f"HuggingFaceM4/{model_name}")
        # processor.push_to_hub(f"nielsr/{model_name}")


def main():
    """Parse the command line and run the conversion."""
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="siglip-so400m-14-384",
        type=str,
        choices=["siglip-so400m-14-384"],
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default="./", type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )
    parser.add_argument(
        "--checkpoint_path",
        default=DEFAULT_CHECKPOINT_PATH,
        type=str,
        help="Path to the original .npz checkpoint to convert.",
    )

    args = parser.parse_args()
    convert_siglip_checkpoint(
        args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, checkpoint_path=args.checkpoint_path
    )


if __name__ == "__main__":
    main()