# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re

import requests
import torch

# git clone https://github.com/salesforce/BLIP.git
from models.blip import blip_decoder
from models.blip_itm import blip_itm
from models.blip_vqa import blip_vqa
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from transformers import (
    BertTokenizer,
    BlipConfig,
    BlipForConditionalGeneration,
    BlipForImageTextRetrieval,
    BlipForQuestionAnswering,
)
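# Note: the `models.*` imports above come from the cloned Salesforce BLIP repository
# (see the `git clone` comment); that repo must be importable (e.g. on PYTHONPATH),
# since it provides the reference models whose checkpoints this script converts.
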
def load_demo_image(image_size, device):
    """Download the BLIP demo image and preprocess it (bicubic resize + CLIP-style normalization) for the parity checks below."""
    img_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    image = transform(raw_image).unsqueeze(0).to(device)
    return image

def rename_key(key):
    """Translate a key from the original BLIP vision encoder into the transformers naming scheme."""
    if "visual_encoder" in key:
        key = re.sub("visual_encoder*", "vision_model.encoder", key)
    if "blocks" in key:
        key = re.sub(r"blocks", "layers", key)
    if "attn" in key:
        key = re.sub(r"attn", "self_attn", key)
    if "norm1" in key:
        key = re.sub(r"norm1", "layer_norm1", key)
    if "norm2" in key:
        key = re.sub(r"norm2", "layer_norm2", key)
    if "encoder.norm" in key:
        key = re.sub(r"encoder.norm", "post_layernorm", key)
    if "encoder.patch_embed.proj" in key:
        key = re.sub(r"encoder.patch_embed.proj", "embeddings.patch_embedding", key)
    if "encoder.pos_embed" in key:
        key = re.sub(r"encoder.pos_embed", "embeddings.position_embedding", key)
    if "encoder.cls_token" in key:
        key = re.sub(r"encoder.cls_token", "embeddings.class_embedding", key)
    if "self_attn" in key:
        key = re.sub(r"self_attn.proj", "self_attn.projection", key)
    return key

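# Illustrative example (hypothetical key, not read from a real checkpoint):
# rename_key("visual_encoder.blocks.0.attn.proj.weight") returns
# "vision_model.encoder.layers.0.self_attn.projection.weight".
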
def convert_blip_checkpoint(pytorch_dump_folder_path, config_path=None):
    """
    Copy/paste/tweak the original BLIP weights into the transformers design.
    """
    if config_path is not None:
        config = BlipConfig.from_pretrained(config_path)
    else:
        config = BlipConfig(projection_dim=512, text_config={}, vision_config={})

    # 1. Captioning model: convert the image-captioning checkpoint into BlipForConditionalGeneration.
    hf_model = BlipForConditionalGeneration(config).eval()

    model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth"
    pt_model = blip_decoder(pretrained=model_url, image_size=384, vit="base")
    pt_model = pt_model.eval()

    # Rename every original key to its transformers counterpart.
    modified_state_dict = pt_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_model.load_state_dict(modified_state_dict)

    image_size = 384
    image = load_demo_image(image_size=image_size, device="cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids = tokenizer(["a picture of"]).input_ids

    # Sanity check: greedy generations (with and without the "a picture of" prompt)
    # must match the reference token ids produced by the original checkpoint.
    out = hf_model.generate(image, input_ids)
    assert out[0].tolist() == [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    out = hf_model.generate(image)
    assert out[0].tolist() == [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]

    if pytorch_dump_folder_path is not None:
        hf_model.save_pretrained(pytorch_dump_folder_path)
    # 2. VQA model: apply the same conversion to the BLIP VQA checkpoint.
    # model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth'
    model_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth"
    )

    vqa_model = blip_vqa(pretrained=model_url, image_size=image_size, vit="base")
    vqa_model.eval()

    modified_state_dict = vqa_model.state_dict()
    for key in modified_state_dict.copy():
        value = modified_state_dict.pop(key)
        renamed_key = rename_key(key)
        modified_state_dict[renamed_key] = value

    hf_vqa_model = BlipForQuestionAnswering(config)
    hf_vqa_model.load_state_dict(modified_state_dict)

    question = ["How many dogs are in this image?"]
    question_input_ids = tokenizer(question, return_tensors="pt").input_ids

    answer = hf_vqa_model.generate(question_input_ids, image)
    print(tokenizer.decode(answer[0]))
    assert tokenizer.decode(answer[0]) == "[UNK] 1 [SEP]"

    if pytorch_dump_folder_path is not None:
        hf_vqa_model.save_pretrained(pytorch_dump_folder_path + "_vqa")
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth" | |
itm_model = blip_itm(pretrained=model_url, image_size=image_size, vit="base") | |
itm_model.eval() | |
modified_state_dict = itm_model.state_dict() | |
for key in modified_state_dict.copy(): | |
value = modified_state_dict.pop(key) | |
renamed_key = rename_key(key) | |
modified_state_dict[renamed_key] = value | |
hf_itm_model = BlipForImageTextRetrieval(config) | |
question = ["A picture of a woman with a dog sitting in a beach"] | |
question_input_ids = tokenizer( | |
question, | |
return_tensors="pt", | |
padding="max_length", | |
truncation=True, | |
max_length=35, | |
).input_ids | |
hf_itm_model.load_state_dict(modified_state_dict) | |
hf_itm_model.eval() | |
out_itm = hf_itm_model(question_input_ids, image, use_itm_head=True) | |
out = hf_itm_model(question_input_ids, image, use_itm_head=False) | |
assert out[0].item() == 0.2110687494277954 | |
assert torch.nn.functional.softmax(out_itm[0], dim=1)[:, 1].item() == 0.45698845386505127 | |
if pytorch_dump_folder_path is not None: | |
hf_itm_model.save_pretrained(pytorch_dump_folder_path + "_itm") | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    args = parser.parse_args()

    # The function only takes the dump folder and the optional config path; there is no --checkpoint_path argument.
    convert_blip_checkpoint(args.pytorch_dump_folder_path, args.config_path)
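
# Example invocation (hypothetical paths, assuming the original BLIP repo is
# importable as described in the note after the imports):
#
#   python convert_blip_original_pytorch_to_hf.py \
#       --pytorch_dump_folder_path ./blip-base-converted \
#       --config_path ./blip_config.json
#
# The captioning, VQA and ITM weights are then saved under ./blip-base-converted,
# ./blip-base-converted_vqa and ./blip-base-converted_itm respectively, and each
# folder can be reloaded with the matching class, e.g.
# BlipForConditionalGeneration.from_pretrained("./blip-base-converted").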