import argparse

import torch
from transformers import CLIPConfig, CLIPModel
from transformers.models.clip.convert_clip_original_pytorch_to_hf import (
    copy_text_model_and_projection,
    copy_vison_model_and_projection,
)

from clip.clip import build_model


@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.

    Adapted from
    https://github.com/huggingface/transformers/blob/3723329d014a7b144863e597ea4fe6de5e6a8279/src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py#LL108C1-L138C55
    """
    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = CLIPModel(config).eval()

    # Load the pre-trained checkpoint; it can be downloaded from the OneDrive link shared by the
    # authors: https://1drv.ms/u/s!ApXgPqe9kykTgwD4Np3-f7ODAot8?e=zLVlJ2
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    pt_model = build_model(checkpoint["state_dict"])
    pt_model = pt_model.float()
    pt_model = pt_model.eval()

    # Copy the text/vision towers and projection heads into the transformers model.
    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # Sanity check: both implementations should produce (nearly) identical logits on dummy inputs.
    input_ids = torch.arange(0, 77).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    hf_logits_per_image = hf_outputs.logits_per_image
    hf_logits_per_text = hf_outputs.logits_per_text
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
    assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)

    hf_model.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default="PubMedCLIP_ViT32.pth", type=str, help="Path to PubMedCLIP ViT32 checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
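
# Example usage (a sketch; the script filename and output folder name below are hypothetical,
# everything else matches the arguments defined above):
#
#   python convert_pubmedclip_to_hf.py \
#       --checkpoint_path PubMedCLIP_ViT32.pth \
#       --pytorch_dump_folder_path pubmedclip-vit-base-patch32-hf
#
# The dumped folder can then be loaded with the standard transformers API, e.g.
# CLIPModel.from_pretrained("pubmedclip-vit-base-patch32-hf").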