#!/usr/bin/env python3
import argparse
import json

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from timm.models import create_model

from transformers import (
    BeitImageProcessor,
    Data2VecVisionConfig,
    Data2VecVisionForImageClassification,
    Data2VecVisionModel,
)


def create_rename_keys(config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec."):
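    """Map original BEiT/data2vec weight names to their HF Transformers equivalents.

    Returns a list of (original_name, hf_name) pairs, e.g.
    ("blocks.0.norm1.weight", f"{hf_prefix}encoder.layer.0.layernorm_before.weight").
    The packed qkv projection weights are handled separately in `read_in_q_k_v`.
    """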
prefix = "backbone." if is_semantic else "" | |
rename_keys = [] | |
for i in range(config.num_hidden_layers): | |
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms | |
rename_keys.append( | |
(f"{prefix}blocks.{i}.norm1.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_before.weight") | |
) | |
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_before.bias")) | |
rename_keys.append( | |
(f"{prefix}blocks.{i}.attn.proj.weight", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.weight") | |
) | |
rename_keys.append( | |
(f"{prefix}blocks.{i}.attn.proj.bias", f"{hf_prefix}encoder.layer.{i}.attention.output.dense.bias") | |
) | |
rename_keys.append( | |
(f"{prefix}blocks.{i}.norm2.weight", f"{hf_prefix}encoder.layer.{i}.layernorm_after.weight") | |
) | |
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"{hf_prefix}encoder.layer.{i}.layernorm_after.bias")) | |
rename_keys.append( | |
(f"{prefix}blocks.{i}.mlp.fc1.weight", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.weight") | |
) | |
rename_keys.append( | |
(f"{prefix}blocks.{i}.mlp.fc1.bias", f"{hf_prefix}encoder.layer.{i}.intermediate.dense.bias") | |
) | |
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"{hf_prefix}encoder.layer.{i}.output.dense.weight")) | |
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"{hf_prefix}encoder.layer.{i}.output.dense.bias")) | |
# projection layer + position embeddings | |
rename_keys.extend( | |
[ | |
(f"{prefix}cls_token", f"{hf_prefix}embeddings.cls_token"), | |
(f"{prefix}patch_embed.proj.weight", f"{hf_prefix}embeddings.patch_embeddings.projection.weight"), | |
(f"{prefix}patch_embed.proj.bias", f"{hf_prefix}embeddings.patch_embeddings.projection.bias"), | |
] | |
) | |
if has_lm_head: | |
# mask token + shared relative position bias + layernorm | |
rename_keys.extend( | |
[ | |
("mask_token", f"{hf_prefix}embeddings.mask_token"), | |
( | |
"rel_pos_bias.relative_position_bias_table", | |
f"{hf_prefix}encoder.relative_position_bias.relative_position_bias_table", | |
), | |
( | |
"rel_pos_bias.relative_position_index", | |
f"{hf_prefix}encoder.relative_position_bias.relative_position_index", | |
), | |
("norm.weight", "layernorm.weight"), | |
("norm.bias", "layernorm.bias"), | |
] | |
) | |
elif is_semantic: | |
# semantic segmentation classification heads | |
rename_keys.extend( | |
[ | |
("decode_head.conv_seg.weight", "decode_head.classifier.weight"), | |
("decode_head.conv_seg.bias", "decode_head.classifier.bias"), | |
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"), | |
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"), | |
] | |
) | |
else: | |
# layernorm + classification head | |
rename_keys.extend( | |
[ | |
("fc_norm.weight", f"{hf_prefix}pooler.layernorm.weight"), | |
("fc_norm.bias", f"{hf_prefix}pooler.layernorm.bias"), | |
("head.weight", "classifier.weight"), | |
("head.bias", "classifier.bias"), | |
] | |
) | |
return rename_keys | |
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False, hf_prefix="data2vec_vision."):
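    """Split each layer's packed qkv projection into separate HF query/key/value tensors.

    The original checkpoint stores a single `attn.qkv.weight` of shape
    (3 * hidden_size, hidden_size); its first, middle and last thirds become the
    query, key and value weights. Only `q_bias` and `v_bias` exist in the original
    state dict (the key projection carries no bias). This also moves the layer-scale
    parameters gamma_1/gamma_2 to lambda_1/lambda_2 and, when there is no LM head,
    the per-layer relative position bias table and index.
    """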
    for i in range(config.num_hidden_layers):
        prefix = "backbone." if is_semantic else ""
        # queries, keys and values
        in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
        q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
        v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")

        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{hf_prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias

        # gamma_1 and gamma_2
        # we call them lambda because otherwise they are renamed when using .from_pretrained
        gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
        gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")

        state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_1"] = gamma_1
        state_dict[f"{hf_prefix}encoder.layer.{i}.lambda_2"] = gamma_2

        # relative_position bias table + index
        if not has_lm_head:
            # each layer has its own relative position bias
            table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
            index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")

            state_dict[
                f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
            ] = table
            state_dict[
                f"{hf_prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
            ] = index


def get_args():
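    """Parse the command-line arguments of the conversion script."""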
    parser = argparse.ArgumentParser(
        "Convert Data2VecVision to HF for image classification and pretraining", add_help=False
    )
    parser.add_argument("--hf_checkpoint_name", type=str)
    parser.add_argument("--input_size", default=224, type=int, help="images input size")
    parser.add_argument("--beit_checkpoint", default="", help="beit checkpoint")

    return parser.parse_args()


def load_beit_model(args, is_finetuned, is_large):
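    """Instantiate the original timm BEiT model and load the data2vec checkpoint into it."""
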
    def load_state_dict(model, state_dict, prefix="", ignore_missing="relative_position_index"):
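        """Load `state_dict` into `model` non-strictly: missing and unexpected keys
        are printed instead of raising, and missing keys matching any pipe-separated
        pattern in `ignore_missing` are silently ignored."""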
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        load(model, prefix=prefix)

        warn_missing_keys = []
        ignore_missing_keys = []
        for key in missing_keys:
            keep_flag = True
            for ignore_key in ignore_missing.split("|"):
                if ignore_key in key:
                    keep_flag = False
                    break
            if keep_flag:
                warn_missing_keys.append(key)
            else:
                ignore_missing_keys.append(key)

        missing_keys = warn_missing_keys

        if len(missing_keys) > 0:
            print(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys
                )
            )
        if len(unexpected_keys) > 0:
            print("Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys))
        if len(ignore_missing_keys) > 0:
            print(
                "Ignored weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, ignore_missing_keys
                )
            )
        if len(error_msgs) > 0:
            print("\n".join(error_msgs))

    model_kwargs = {
        "pretrained": False,
        "use_shared_rel_pos_bias": True,
        "use_abs_pos_emb": False,
        "init_values": 0.1,
    }

    if is_finetuned:
        model_kwargs.update(
            {
                "num_classes": 1000,
                "use_mean_pooling": True,
                "init_scale": 0.001,
                "use_rel_pos_bias": True,
            }
        )

    model = create_model(
        "beit_large_patch16_224" if is_large else "beit_base_patch16_224",
        **model_kwargs,
    )
    patch_size = model.patch_embed.patch_size
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    checkpoint = torch.load(args.beit_checkpoint, map_location="cpu")

    print(f"Load ckpt from {args.beit_checkpoint}")
    checkpoint_model = None
    for model_key in ("model", "module"):
        if model_key in checkpoint:
            checkpoint_model = checkpoint[model_key]
            print(f"Load state_dict by model_key = {model_key}")
            break

    all_keys = list(checkpoint_model.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            checkpoint_model.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = model.state_dict()[key].size()
            dst_patch_shape = model.patch_embed.patch_shape
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()

    load_state_dict(model, checkpoint_model, prefix="")

    return model


def main():
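    """Run the end-to-end conversion.

    Builds the HF config, loads the original BEiT/data2vec model, renames and
    reshapes its weights into a fresh HF Data2VecVision model, checks that both
    models produce the same outputs on a test image, and saves the result.
    """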
    args = get_args()

    is_finetuned = "ft1k" in args.hf_checkpoint_name
    is_large = "large" in args.hf_checkpoint_name

    if is_finetuned:
        # To convert Beit's data2vec_vision to HF you need to copy
        # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_finetune.py
        # into this folder.
        import modeling_finetune  # noqa: F401
    else:
        # To convert Beit's data2vec_vision to HF you need to copy
        # https://github.com/facebookresearch/data2vec_vision/blob/main/beit/modeling_cyclical.py
        # into this folder.
        # IMPORTANT: Note that for now we've only converted the down-stream
        # model and not the full pretrained model. This means for the integration
        # test you need to add a `return x` after the following line:
        # https://github.com/facebookresearch/data2vec_vision/blob/af9a36349aaed59ae66e69b5dabeef2d62fdc5da/beit/modeling_cyclical.py#L197
        # to make the integration test pass.
        import modeling_cyclical  # noqa: F401

    # 1. Create model config
    config = Data2VecVisionConfig()
    if is_finetuned:
        config.use_relative_position_bias = True
        config.use_shared_relative_position_bias = False
        config.use_mean_pooling = True
        config.num_labels = 1000

        repo_id = "huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    else:
        config.use_relative_position_bias = False
        config.use_shared_relative_position_bias = True
        config.use_mean_pooling = False

    if is_large:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # 2. Load Beit model
    orig_model = load_beit_model(args, is_finetuned, is_large)
    orig_model.eval()

    # 3. Forward Beit model
    image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
    image = Image.open("../../../../tests/fixtures/tests_samples/COCO/000000039769.png")
    encoding = image_processor(images=image, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    orig_args = (pixel_values,) if is_finetuned else (pixel_values, None)
    with torch.no_grad():
        orig_model_output = orig_model(*orig_args)

    # 4. Load HF Data2VecVision model
    if is_finetuned:
        hf_model = Data2VecVisionForImageClassification(config)
        hf_model.eval()
        has_lm_head = False
        hf_prefix = "data2vec_vision."
    else:
        hf_model = Data2VecVisionModel(config)
        hf_model.eval()
        has_lm_head = True
        hf_prefix = ""

    rename_keys = create_rename_keys(config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
    state_dict = orig_model.state_dict()
    for src, dest in rename_keys:
        val = state_dict.pop(src)
        state_dict[dest] = val

    read_in_q_k_v(state_dict, config, hf_prefix=hf_prefix, has_lm_head=has_lm_head)
    missing_keys, unexpected_keys = hf_model.load_state_dict(state_dict, strict=False)
    print("HF missing", missing_keys)
    print("HF unexpected_keys", unexpected_keys)

    # 5. Forward HF Data2VecVision model
    with torch.no_grad():
        hf_model_output = hf_model(pixel_values)

    hf_output = hf_model_output.logits if is_finetuned else hf_model_output.last_hidden_state

    # 6. Compare
    max_absolute_diff = torch.max(torch.abs(hf_output - orig_model_output)).item()

    print(f"max_absolute_diff = {max_absolute_diff}")
    success = torch.allclose(hf_output, orig_model_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")

    if not success:
        raise Exception("Something went wRoNg")

    # 7. Save
    print(f"Saving to {args.hf_checkpoint_name}")
    hf_model.save_pretrained(args.hf_checkpoint_name)
    image_processor.save_pretrained(args.hf_checkpoint_name)


if __name__ == "__main__":
    main()


# Run the following to convert checkpoints:
#   python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
#       --beit_checkpoint ./pretrained_base.pt \
#       --hf_checkpoint_name "./data2vec-vision-base"
#   python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
#       --beit_checkpoint ./finetuned_base.pt \
#       --hf_checkpoint_name "./data2vec-vision-base-ft1k"
#   python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
#       --beit_checkpoint ./pretrained_large.pt \
#       --hf_checkpoint_name "./data2vec-vision-large"
#   python ./convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py \
#       --beit_checkpoint ./finetuned_large.pt \
#       --hf_checkpoint_name "./data2vec-vision-large-ft1k"
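#
# After conversion, the saved folder behaves like any HF checkpoint. A minimal
# sanity-check sketch (the folder is the one produced by the ft1k command above;
# the image path "cat.png" is a placeholder):
#
#   from PIL import Image
#   from transformers import BeitImageProcessor, Data2VecVisionForImageClassification
#
#   processor = BeitImageProcessor.from_pretrained("./data2vec-vision-base-ft1k")
#   model = Data2VecVisionForImageClassification.from_pretrained("./data2vec-vision-base-ft1k")
#   inputs = processor(images=Image.open("cat.png"), return_tensors="pt")
#   logits = model(**inputs).logits
#   print(model.config.id2label[logits.argmax(-1).item()])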