# coding=utf-8 # Copyright 2022 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import re import requests import torch # git clone https://github.com/salesforce/BLIP.git from models.blip import blip_decoder from models.blip_itm import blip_itm from models.blip_vqa import blip_vqa from PIL import Image from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from transformers import ( BertTokenizer, BlipConfig, BlipForConditionalGeneration, BlipForImageTextRetrieval, BlipForQuestionAnswering, ) def load_demo_image(image_size, device): img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg" raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") transform = transforms.Compose( [ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ] ) image = transform(raw_image).unsqueeze(0).to(device) return image def rename_key(key): if "visual_encoder" in key: key = re.sub("visual_encoder*", "vision_model.encoder", key) if "blocks" in key: key = re.sub(r"blocks", "layers", key) if "attn" in key: key = re.sub(r"attn", "self_attn", key) if "norm1" in key: key = re.sub(r"norm1", "layer_norm1", key) if "norm2" in key: key = re.sub(r"norm2", "layer_norm2", key) if "encoder.norm" in key: key = re.sub(r"encoder.norm", "post_layernorm", key) if "encoder.patch_embed.proj" in key: key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key) if "encoder.pos_embed" in key: key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key) if "encoder.cls_token" in key: key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key) if "self_attn" in key: key = re.sub(r"self_attn.proj", "self_attn.projection", key) return key @torch.no_grad() def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None): """ Copy/paste/tweak model's weights to transformers design. """ if config_path is not None: config = BlipConfig.from_pretrained(config_path) else: config = BlipConfig(projection_dim=512, text_config={}, vision_config={}) hf_model = BlipForConditionalGeneration(config).eval() model_url = "model_base_capfilt_large.pth" # pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base") # pt_model = pt_model.eval() # modified_state_dict = pt_model.state_dict() # for key in modified_state_dict.copy(): # value = modified_state_dict.pop(key) # renamed_key = rename_key(key) # modified_state_dict[renamed_key] = value # # hf_model.load_state_dict(modified_state_dict) # image_size = 384 image = load_demo_image(image_size=image_size, device="cpu") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # input_ids = tokenizer(["a picture of"]).input_ids # # out = hf_model.generate(image, input_ids) # # assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102] # # out = hf_model.generate(image) # # assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102] # # if pytorch_dump_folder_path is not None: # hf_model.save_pretrained(pytorch_dump_folder_path) # # # model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth' # model_url = ( # # "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth" # # ) vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base") vqa_model.eval() modified_state_dict = vqa_model.state_dict() for key in modified_state_dict.copy(): value = modified_state_dict.pop(key) renamed_key = rename_key(key) modified_state_dict[renamed_key] = value hf_vqa_model = BlipForQuestionAnswering(config) offset_keys = [i for i in modified_state_dict.keys() if i not in hf_vqa_model.state_dict().keys()] print(len([i for i in hf_vqa_model.state_dict().keys() if i in modified_state_dict.keys()])) for key in offset_keys: modified_state_dict.pop(key) hf_vqa_model.load_state_dict(modified_state_dict) question = ["How many dogs are in this image?"] question_input_ids = tokenizer(question, return_tensors="pt").input_ids answer = hf_vqa_model.generate(question_input_ids, image) print(tokenizer.decode(answer[0])) # assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]" if pytorch_dump_folder_path is not None: hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa") # model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth" # # itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base") # itm_model.eval() # # modified_state_dict = itm_model.state_dict() # for key in modified_state_dict.copy(): # value = modified_state_dict.pop(key) # renamed_key = rename_key(key) # modified_state_dict[renamed_key] = value # # hf_itm_model = BlipForImageTextRetrieval(config) # # question = ["A picture of a woman with a dog sitting in a beach"] # question_input_ids = tokenizer( # question, # return_tensors="pt", # padding="max_length", # truncation=True, # max_length=35, # ).input_ids # # hf_itm_model.load_state_dict(modified_state_dict) # hf_itm_model.eval() # # out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True) # out = hf_itm_model(question_input_ids, image, use_itm_head=False) # # assert out[0].item() == 0.2110687494277954 # assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127 # # if pytorch_dump_folder_path is not None: # hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") args = parser.parse_args() convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)