# Grounded-Segment-Anything/transformers_4_35_0/models/clip/convert_clip_original_pytorch_to_hf.py
# liuyizhang
# add transformers_4_35_0
# 1ce5e18
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from clip import load
from transformers import CLIPConfig, CLIPModel
def copy_attn_layer(hf_attn_layer, pt_attn_layer):
    """Copy attention weights from a fused-QKV attention module into an HF attention layer.

    The source module keeps the query, key and value projections stacked along
    dim 0 of ``in_proj_weight`` / ``in_proj_bias``. They are split into three
    equal chunks (in that order) and written to ``q_proj``, ``k_proj`` and
    ``v_proj``; the output projection is copied one-to-one.

    Every tensor is cloned before it is stored, so the destination layer owns
    its memory. Without the clone the three chunks are views into one shared
    storage, which storage-aware serializers may treat as a single tied tensor.

    Args:
        hf_attn_layer: Destination layer exposing ``q_proj``, ``k_proj``,
            ``v_proj`` and ``out_proj`` sub-modules, each with ``weight`` and
            ``bias``.
        pt_attn_layer: Source layer exposing ``in_proj_weight``,
            ``in_proj_bias`` and an ``out_proj`` sub-module.
    """
    q_weight, k_weight, v_weight = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
    q_bias, k_bias, v_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)
    pt_out_proj = pt_attn_layer.out_proj

    assignments = (
        (hf_attn_layer.q_proj, q_weight, q_bias),
        (hf_attn_layer.k_proj, k_weight, k_bias),
        (hf_attn_layer.v_proj, v_weight, v_bias),
        (hf_attn_layer.out_proj, pt_out_proj.weight, pt_out_proj.bias),
    )
    for hf_proj, weight, bias in assignments:
        # All four projections are written the same way: through `.data`, with
        # an independent copy, so the destination keeps its own Parameter
        # objects and never aliases the source model's memory.
        hf_proj.weight.data = weight.detach().clone()
        hf_proj.bias.data = bias.detach().clone()
def copy_mlp(hf_mlp, pt_mlp):
    """Transfer both feed-forward projections of an MLP block.

    ``pt_mlp.c_fc`` goes to ``hf_mlp.fc1`` and ``pt_mlp.c_proj`` goes to
    ``hf_mlp.fc2``; each pair is handed to ``copy_linear``.
    """
    name_pairs = (("fc1", "c_fc"), ("fc2", "c_proj"))
    for hf_name, pt_name in name_pairs:
        copy_linear(getattr(hf_mlp, hf_name), getattr(pt_mlp, pt_name))
def copy_linear(hf_linear, pt_linear):
    """Re-bind ``weight`` and ``bias`` of ``hf_linear`` to those of ``pt_linear``.

    The attributes are assigned, not copied, so afterwards both modules refer
    to the same objects. Despite the name, any pair of modules exposing
    ``weight`` and ``bias`` works (this file also passes layer norms).
    """
    for attr_name in ("weight", "bias"):
        setattr(hf_linear, attr_name, getattr(pt_linear, attr_name))
def copy_layer(hf_layer, pt_layer):
    """Copy one transformer block from ``pt_layer`` into ``hf_layer``.

    Covers the self-attention (``attn`` -> ``self_attn``), the MLP
    (``mlp`` -> ``mlp``) and the two layer norms (``ln_1`` -> ``layer_norm1``,
    ``ln_2`` -> ``layer_norm2``).
    """
    # attention
    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)

    # feed-forward
    copy_mlp(hf_layer.mlp, pt_layer.mlp)

    # layer norms: copy_linear only touches `weight` and `bias`, which is all
    # that is transferred for a norm here
    norm_pairs = (
        (hf_layer.layer_norm1, pt_layer.ln_1),
        (hf_layer.layer_norm2, pt_layer.ln_2),
    )
    for hf_norm, pt_norm in norm_pairs:
        copy_linear(hf_norm, pt_norm)
def copy_layers(hf_layers, pt_layers):
    """Copy every block of ``pt_layers`` into the block at the same index of ``hf_layers``.

    Args:
        hf_layers: Iterable of destination transformer blocks.
        pt_layers: Iterable of source transformer blocks.

    Raises:
        ValueError: If the two stacks do not contain the same number of
            blocks. A plain ``zip`` would stop at the shorter one and leave
            the remaining blocks unconverted without any error.
    """
    hf_layers = list(hf_layers)
    pt_layers = list(pt_layers)
    if len(hf_layers) != len(pt_layers):
        raise ValueError(
            f"Layer count mismatch: destination has {len(hf_layers)} layers "
            f"but source has {len(pt_layers)}."
        )

    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
        copy_layer(hf_layer, pt_layer)
def copy_encoder(hf_encoder, pt_model):
    """Copy the text encoder of ``pt_model`` into ``hf_encoder``.

    Transfers the token and positional embeddings, all transformer blocks
    (``pt_model.transformer.resblocks``) and the final layer norm
    (``pt_model.ln_final``).
    """
    hf_embeddings = hf_encoder.embeddings

    # embeddings: the token table is re-bound to the source parameter, the
    # positional table is written through `.data`
    hf_embeddings.token_embedding.weight = pt_model.token_embedding.weight
    hf_embeddings.position_embedding.weight.data = pt_model.positional_embedding

    # hidden layers
    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)

    # final layer norm
    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)
def copy_text_model_and_projection(hf_model, pt_model):
    """Copy the text projection and the text encoder from ``pt_model`` into ``hf_model``.

    The source exposes ``text_projection`` as a bare matrix, whereas the
    destination holds it as the ``weight`` of a sub-module, so the matrix is
    transposed on the way in.

    Args:
        hf_model: Destination model with ``text_projection`` and ``text_model``.
        pt_model: Source model with ``text_projection`` plus the attributes
            read by ``copy_encoder``.
    """
    # `.T` alone yields a strided view of the source matrix; `.contiguous()`
    # materializes it so the destination weight has a dense layout and its own
    # memory (some serializers refuse non-contiguous tensors).
    hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()

    # text encoder
    copy_encoder(hf_model.text_model, pt_model)
def copy_vison_model_and_projection(hf_model, pt_model):
    """Copy the visual projection and the vision tower from ``pt_model`` into ``hf_model``.

    Transfers, from ``pt_model.visual``: the projection matrix (transposed into
    ``hf_model.visual_projection.weight``), the pre/post layer norms, the
    patch, class and positional embeddings, and all transformer blocks.

    Args:
        hf_model: Destination model with ``visual_projection`` and
            ``vision_model``.
        pt_model: Source model with a ``visual`` sub-module.
    """
    pt_visual = pt_model.visual
    hf_vision = hf_model.vision_model

    # projection: `.T` alone yields a strided view of the source matrix;
    # `.contiguous()` materializes it so the destination weight has a dense
    # layout and its own memory (some serializers refuse non-contiguous
    # tensors).
    hf_model.visual_projection.weight.data = pt_visual.proj.data.T.contiguous()

    # layer norms (`pre_layrnorm` is the attribute name this code relies on;
    # do not "correct" the spelling here)
    copy_linear(hf_vision.pre_layrnorm, pt_visual.ln_pre)
    copy_linear(hf_vision.post_layernorm, pt_visual.ln_post)

    # embeddings
    hf_embeddings = hf_vision.embeddings
    hf_embeddings.patch_embedding.weight.data = pt_visual.conv1.weight.data
    hf_embeddings.class_embedding = pt_visual.class_embedding
    hf_embeddings.position_embedding.weight.data = pt_visual.positional_embedding.data

    # encoder blocks
    copy_layers(hf_vision.encoder.layers, pt_visual.transformer.resblocks)
@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.

    Builds a `CLIPModel`, fills it with the weights of the original model,
    checks on a dummy batch that both models produce the same logits, and
    saves the converted model.

    Args:
        checkpoint_path: Passed unchanged to `clip.load` to obtain the
            original model (loaded on CPU, non-JIT).
        pytorch_dump_folder_path: Location handed to `save_pretrained`.
        config_path: Optional argument for `CLIPConfig.from_pretrained`. When
            `None`, a `CLIPConfig` with `projection_dim=512` and default text
            and vision sub-configs is used.

    Raises:
        AssertionError: If the converted model's logits differ from the
            original's beyond `atol=1e-3`. Nothing is saved in that case.
    """
    if config_path is None:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
    else:
        config = CLIPConfig.from_pretrained(config_path)

    hf_model = CLIPModel(config).eval()

    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
    pt_model = pt_model.eval()

    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # Dummy batch for the equivalence check: one sequence of 77 consecutive
    # token ids and one random 3x224x224 image.
    # NOTE(review): presumably 77 / 224 are the context length and image size
    # of the checkpoint being converted -- confirm when using a custom config.
    # NOTE(review): confirm these synthetic ids make both implementations pool
    # the text features at the same position for the config in use.
    input_ids = torch.arange(0, 77).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    comparisons = (
        ("logits_per_image", hf_outputs.logits_per_image, pt_logits_per_image),
        ("logits_per_text", hf_outputs.logits_per_text, pt_logits_per_text),
    )
    for name, hf_logits, pt_logits in comparisons:
        # Raised explicitly rather than via `assert`, which is removed under
        # `python -O` and would let a broken conversion be saved silently.
        if not torch.allclose(hf_logits, pt_logits, atol=1e-3):
            max_abs_diff = (hf_logits - pt_logits).abs().max().item()
            raise AssertionError(
                f"Converted model disagrees with the original model on {name} "
                f"(max abs diff {max_abs_diff:.3e}, atol 1e-3)."
            )

    hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # Command-line entry point: parse the three paths and run the conversion.
    parser = argparse.ArgumentParser()
    # Both paths below are mandatory for the conversion, so argparse rejects a
    # missing one up front instead of the run failing later on a `None` path.
    parser.add_argument(
        "--pytorch_dump_folder_path",
        required=True,
        type=str,
        help="Path to the output PyTorch model.",
    )
    parser.add_argument(
        "--checkpoint_path",
        required=True,
        type=str,
        help="Path to the original CLIP checkpoint (passed to clip.load)",
    )
    parser.add_argument(
        "--config_path",
        default=None,
        type=str,
        help="Path to hf config.json of model to convert",
    )
    args = parser.parse_args()

    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)