# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a CLIP checkpoint loaded through `clip.load` into a `transformers` `CLIPModel`."""

import argparse

import torch
from clip import load

from transformers import CLIPConfig, CLIPModel


def copy_attn_layer(hf_attn_layer, pt_attn_layer):
    """Copy attention weights from `pt_attn_layer` into `hf_attn_layer`.

    The source layer keeps its query/key/value projections fused in
    `in_proj_weight` / `in_proj_bias`; they are split into three equal chunks
    along dim 0 (in q, k, v order) and written to the separate `q_proj`,
    `k_proj` and `v_proj` modules of the target. The output projection is
    copied as-is.
    """
    q_proj, k_proj, v_proj = pt_attn_layer.in_proj_weight.chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj_bias.chunk(3, dim=0)

    out_proj_weights = pt_attn_layer.out_proj.weight
    out_proj_bias = pt_attn_layer.out_proj.bias

    hf_attn_layer.q_proj.weight.data = q_proj
    hf_attn_layer.q_proj.bias.data = q_proj_bias

    hf_attn_layer.k_proj.weight.data = k_proj
    hf_attn_layer.k_proj.bias.data = k_proj_bias

    hf_attn_layer.v_proj.weight.data = v_proj
    hf_attn_layer.v_proj.bias.data = v_proj_bias

    hf_attn_layer.out_proj.weight = out_proj_weights
    hf_attn_layer.out_proj.bias = out_proj_bias


def copy_mlp(hf_mlp, pt_mlp):
    """Copy the two MLP projections: `c_fc` -> `fc1` and `c_proj` -> `fc2`."""
    copy_linear(hf_mlp.fc1, pt_mlp.c_fc)
    copy_linear(hf_mlp.fc2, pt_mlp.c_proj)


def copy_linear(hf_linear, pt_linear):
    """Point `hf_linear.weight` / `.bias` at the parameters of `pt_linear`.

    Despite the name, this only relies on both modules exposing `weight` and
    `bias`, so it is also used below to copy layer-norm modules.
    """
    hf_linear.weight = pt_linear.weight
    hf_linear.bias = pt_linear.bias


def copy_layer(hf_layer, pt_layer):
    """Copy one transformer block: both layer norms, the MLP and the attention."""
    # copy layer norms
    copy_linear(hf_layer.layer_norm1, pt_layer.ln_1)
    copy_linear(hf_layer.layer_norm2, pt_layer.ln_2)

    # copy MLP
    copy_mlp(hf_layer.mlp, pt_layer.mlp)

    # copy attn
    copy_attn_layer(hf_layer.self_attn, pt_layer.attn)


def copy_layers(hf_layers, pt_layers):
    """Copy transformer blocks pairwise, in order.

    Uses `zip`, so if the two sequences differ in length the extra layers of
    the longer one are silently left untouched.
    """
    for hf_layer, pt_layer in zip(hf_layers, pt_layers):
        copy_layer(hf_layer, pt_layer)


def copy_encoder(hf_encoder, pt_model):
    """Copy the text encoder (embeddings, final layer norm, hidden layers)."""
    # copy embeds
    hf_encoder.embeddings.token_embedding.weight = pt_model.token_embedding.weight
    hf_encoder.embeddings.position_embedding.weight.data = pt_model.positional_embedding

    # copy layer norm
    copy_linear(hf_encoder.final_layer_norm, pt_model.ln_final)

    # copy hidden layers
    copy_layers(hf_encoder.encoder.layers, pt_model.transformer.resblocks)


def copy_text_model_and_projection(hf_model, pt_model):
    """Copy the text projection and the text encoder into `hf_model`."""
    # copy projection
    # `.T` returns a non-contiguous view; materialize it so the stored weight has a
    # regular memory layout (some serializers, e.g. safetensors, reject
    # non-contiguous tensors).
    hf_model.text_projection.weight.data = pt_model.text_projection.data.T.contiguous()

    # copy text encoder
    copy_encoder(hf_model.text_model, pt_model)


def copy_vison_model_and_projection(hf_model, pt_model):
    """Copy the visual projection and the vision encoder into `hf_model`."""
    # copy projection
    # `.T` returns a non-contiguous view; see the note in the text-projection copy.
    hf_model.visual_projection.weight.data = pt_model.visual.proj.data.T.contiguous()

    # copy layer norms
    copy_linear(hf_model.vision_model.pre_layrnorm, pt_model.visual.ln_pre)
    copy_linear(hf_model.vision_model.post_layernorm, pt_model.visual.ln_post)

    # copy embeds
    hf_model.vision_model.embeddings.patch_embedding.weight.data = pt_model.visual.conv1.weight.data
    hf_model.vision_model.embeddings.class_embedding = pt_model.visual.class_embedding
    hf_model.vision_model.embeddings.position_embedding.weight.data = pt_model.visual.positional_embedding.data

    # copy encoder
    copy_layers(hf_model.vision_model.encoder.layers, pt_model.visual.transformer.resblocks)


@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.

    Args:
        checkpoint_path: Checkpoint to convert; passed unchanged to `clip.load`.
        pytorch_dump_folder_path: Directory handed to `save_pretrained` for the
            converted model.
        config_path: Optional location of a `CLIPConfig`. When `None`, a config
            with `projection_dim=512` and default text/vision sub-configs is used.

    Raises:
        ValueError: If the converted model's logits differ from the original
            model's logits by more than `atol=1e-3` on the verification inputs.
    """
    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = CLIPModel(config).eval()

    pt_model, _ = load(checkpoint_path, device="cpu", jit=False)
    pt_model = pt_model.eval()

    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # Verification inputs: one sequence of 77 token ids and one random 3x224x224 image.
    # NOTE(review): these sizes presumably match the default config above; confirm
    # before converting checkpoints with a different context length or resolution.
    input_ids = torch.arange(0, 77).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    hf_logits_per_image = hf_outputs.logits_per_image
    hf_logits_per_text = hf_outputs.logits_per_text
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    # Explicit checks instead of `assert`: asserts are stripped under `python -O`,
    # which would let a mismatching conversion be saved without any warning.
    if not torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3):
        raise ValueError("Converted model's logits_per_image do not match the original model (atol=1e-3).")
    if not torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3):
        raise ValueError("Converted model's logits_per_text do not match the original model (atol=1e-3).")

    hf_model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both paths are consumed unconditionally by the conversion, so require them up
    # front rather than failing later on a `None` value.
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--checkpoint_path",
        required=True,
        type=str,
        help="Path to the original CLIP checkpoint (passed to clip.load)",
    )
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)