flaviagiammarino's picture
initial commit
117ca20
raw
history blame
2.51 kB
import argparse
import torch
from transformers import CLIPConfig, CLIPModel
from transformers.models.clip.convert_clip_original_pytorch_to_hf import copy_text_model_and_projection, copy_vison_model_and_projection
from clip.clip import build_model
@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """Convert an original-format CLIP checkpoint into a Hugging Face ``CLIPModel``.

    Copy/paste/tweak model's weights to transformers design. Adapted from
    https://github.com/huggingface/transformers/blob/3723329d014a7b144863e597ea4fe6de5e6a8279/src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py#LL108C1-L138C55

    Steps:
      1. Rebuild the original model with ``clip.clip.build_model`` from
         ``checkpoint["state_dict"]``.
      2. Copy its text and vision weights into a freshly initialised ``CLIPModel``.
      3. Run both models on the same dummy batch and compare their logits.
      4. Save the converted model, but only if the logits match.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that contains
            a ``"state_dict"`` entry.
        pytorch_dump_folder_path: Output directory passed to ``save_pretrained``.
        config_path: Optional path/identifier for ``CLIPConfig.from_pretrained``.
            When omitted, ``CLIPConfig(projection_dim=512)`` with default text and
            vision sub-configs is used.

    Raises:
        ValueError: If ``pytorch_dump_folder_path`` is ``None``. This is checked
            before any loading, so no conversion work is wasted.
        AssertionError: If the converted logits differ from the original ones
            beyond ``atol=1e-3`` / ``rtol=1e-5``. Unlike a bare ``assert``, this
            check is not stripped under ``python -O``.
    """
    # Fail fast: otherwise the missing output path only surfaces inside
    # save_pretrained, after the full load/copy/verify has already run.
    if pytorch_dump_folder_path is None:
        raise ValueError("pytorch_dump_folder_path must be provided to save the converted model.")

    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})
    hf_model = CLIPModel(config).eval()

    # Load the pre-trained checkpoint. It can be downloaded from the OneDrive link
    # shared by the authors: https://1drv.ms/u/s!ApXgPqe9kykTgwD4Np3-f7ODAot8?e=zLVlJ2
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Cast to float32 and switch to eval mode so both models are compared like-for-like.
    pt_model = build_model(checkpoint["state_dict"]).float().eval()

    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)  # sic: the imported helper's name is spelled "vison"
    hf_model.logit_scale = pt_model.logit_scale

    # Dummy batch for the sanity check: one sequence of 77 token ids and one
    # random 3-channel 224x224 image.
    input_ids = torch.arange(0, 77).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    # Same tolerance as torch.allclose(..., atol=1e-3), i.e. default rtol=1e-5.
    # assert_close raises AssertionError with a mismatch report and, unlike a
    # bare assert, is not removed by `python -O`.
    torch.testing.assert_close(hf_outputs.logits_per_image, pt_logits_per_image, rtol=1e-5, atol=1e-3)
    torch.testing.assert_close(hf_outputs.logits_per_text, pt_logits_per_text, rtol=1e-5, atol=1e-3)

    hf_model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # Command-line entry point: parse the paths, then run the conversion.
    parser = argparse.ArgumentParser(
        description="Convert a PubMedCLIP ViT32 checkpoint to the Hugging Face CLIP format."
    )
    # Required: with a None default the converted model could never be saved, and
    # the failure would only appear after the whole conversion had already run.
    parser.add_argument(
        "--pytorch_dump_folder_path",
        required=True,
        type=str,
        help="Path to the output PyTorch model.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default="PubMedCLIP_ViT32.pth",
        type=str,
        help="Path to PubMedCLIP ViT32 checkpoint",
    )
    parser.add_argument(
        "--config_path",
        default=None,
        type=str,
        help="Path to hf config.json of model to convert",
    )
    args = parser.parse_args()
    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)