import argparse
import torch
from transformers import CLIPConfig, CLIPModel
from transformers.models.clip.convert_clip_original_pytorch_to_hf import copy_text_model_and_projection, copy_vison_model_and_projection
# build_model is provided by the original OpenAI CLIP package (https://github.com/openai/CLIP)
from clip.clip import build_model

@torch.no_grad()
def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak model's weights to transformers design. Adapted from
    https://github.com/huggingface/transformers/blob/3723329d014a7b144863e597ea4fe6de5e6a8279/src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py#LL108C1-L138C55
    """
    if config_path is not None:
        config = CLIPConfig.from_pretrained(config_path)
    else:
        config = CLIPConfig(projection_dim=512, text_config={}, vision_config={})

    hf_model = CLIPModel(config).eval()

    # Load the pre-trained checkpoint; it can be downloaded from the OneDrive link shared by the authors: https://1drv.ms/u/s!ApXgPqe9kykTgwD4Np3-f7ODAot8?e=zLVlJ2
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    pt_model = build_model(checkpoint["state_dict"])
    pt_model = pt_model.float()
    pt_model = pt_model.eval()

    # Copy the text tower, vision tower and projection weights from the original
    # CLIP model into the Hugging Face model, then carry over the logit scale.
    copy_text_model_and_projection(hf_model, pt_model)
    copy_vison_model_and_projection(hf_model, pt_model)
    hf_model.logit_scale = pt_model.logit_scale

    # Run both models on the same dummy inputs (a full 77-token sequence and a
    # random 224x224 image) so the converted weights can be checked numerically.
    input_ids = torch.arange(0, 77).unsqueeze(0)
    pixel_values = torch.randn(1, 3, 224, 224)

    hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
    hf_logits_per_image = hf_outputs.logits_per_image
    hf_logits_per_text = hf_outputs.logits_per_text
    pt_logits_per_image, pt_logits_per_text = pt_model(pixel_values, input_ids)

    # The converted model should reproduce the original logits up to small numerical error.
    assert torch.allclose(hf_logits_per_image, pt_logits_per_image, atol=1e-3)
    assert torch.allclose(hf_logits_per_text, pt_logits_per_text, atol=1e-3)

    hf_model.save_pretrained(pytorch_dump_folder_path)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default="PubMedCLIP_ViT32.pth", type=str, help="Path to PubMedCLIP ViT32 checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()
    convert_clip_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path)
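
# Usage sketch (the script name and output path below are illustrative, not fixed by this repo):
#
#   python convert_pubmedclip_to_hf.py \
#       --checkpoint_path PubMedCLIP_ViT32.pth \
#       --pytorch_dump_folder_path ./pubmedclip-vit-base-patch32
#
# The saved folder can then be loaded with the standard transformers API, e.g.
# CLIPModel.from_pretrained("./pubmedclip-vit-base-patch32"). Note that this script
# only exports model weights; a tokenizer/processor (for instance
# CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")) still has to be
# paired with the model separately.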