Grounded-Segment-Anything / transformers_4_35_0 / models / clip / convert_clip_original_pytorch_to_hf.py
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
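
# This script loads an original OpenAI CLIP checkpoint, copies its weights into a
# transformers `CLIPModel`, sanity-checks that both models produce matching logits,
# and saves the converted model with `save_pretrained`.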

import argparse

import torch
from clip import load
from transformers import CLIPConfig, CLIPModel
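
# NOTE: `load` above comes from the original OpenAI CLIP package
# (https://github.com/openai/CLIP), which is a separate dependency from transformers.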


def copy_attn_layer(hf_attn_layer, pt_attn_layer):
    # the original CLIP attention stores q/k/v as a single fused in_proj;
    # split it into the separate q/k/v projections used by the HF module
    q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)

    out_proj_weights = pt_attn_layer.out_proj.weight
    out_proj_bias = pt_attn_layer.out_proj.bias

    hf_attn_layer.q_proj.weight.data = q_proj
    hf_attn_layer.q_proj.bias.data = q_proj_bias

    hf_attn_layer.k_proj.weight.data = k_proj
    hf_attn_layer.k_proj.bias.data = k_proj_bias

    hf_attn_layer.v_proj.weight.data = v_proj
    hf_attn_layer.v_proj.bias.data = v_proj_bias

    hf_attn_layer.out_proj.weight = out_proj_weights
    hf_attn_layer.out_proj.bias = out_proj_bias


def copy_mlp(hf_mlp, pt_mlp):
    copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
    copy_linear(hf_mlp.fc2, pt_mlp.c_proj)


def copy_linear(hf_linear, pt_linear):
    hf_linear.weight = pt_linear.weight
    hf_linear.bias = pt_linear.bias


def copy_layer(hf_layer, pt_layer):
    # copy layer norms
    copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
    copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)

    # copy MLP
    copy_mlp(hf_layer.mlp, pt_layer.mlp)

    # copy attn
    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)


def copy_layers(hf_layers, pt_layers):
    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
        copy_layer(hf_layer, pt_layer)


def copy_encoder(hf_encoder, pt_model):
    # copy embeds
    hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
    hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding

    # copy layer norm
    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)

    # copy hidden layers
    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)


def copy_text_model_and_projection(hf_model, pt_model):
    # copy projection
    hf_model.text_projection.weight.data = pt_model.text_projection.data.T

    # copy text encoder
    copy_encoder(hf_model.text_model, pt_model)


def copy_vison_model_and_projection(hf_model, pt_model):
    # copy projection
    hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T

    # copy layer norms
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
    copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)

    # copy embeds
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
    hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data

    # copy encoder
    copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)


def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = CLIPModel(config).eval()

    # load the original OpenAI CLIP checkpoint on CPU, as a regular (non-JIT) model
    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
    pt_model = pt_model.eval()

    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # sanity check: both models should produce (nearly) identical logits on the same dummy inputs
    input_ids = torch.arange(0, 77).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    hf_logits_per_image = hf_outputs.logits_per_image
    hf_logits_per_text = hf_outputs.logits_per_text
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
    assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)

    hf_model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the original OpenAI CLIP checkpoint.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
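
# A hypothetical invocation (the checkpoint name and output folder below are
# illustrative placeholders, not values taken from this repo):
#
#   python convert_clip_original_pytorch_to_hf.py \
#       --checkpoint_path ViT-B/32 \
#       --pytorch_dump_folder_path ./clip-vit-base-patch32
#
# `--checkpoint_path` is forwarded to `clip.load`, which accepts either a model name
# known to the OpenAI CLIP package or a path to a downloaded .pt checkpoint;
# `--config_path` is optional and defaults to a fresh `CLIPConfig`.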