winnie22 / transformers_4_35_0 /models /cvt /convert_cvt_original_pytorch_checkpoint_to_pytorch.py
mart9992's picture
m
06ba6ce
raw
history blame
No virus
13.6 kB
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert CvT checkpoints from the original repository.
URL: https://github.com/microsoft/CvT"""
import argparse
import json
from collections import OrderedDict
import torch
from huggingface_hub import cached_download, hf_hub_url
from transformers import AutoImageProcessor, CvtConfig, CvtForImageClassification
def embeddings(idx):
"""
The function helps in renaming embedding layer weights.
Args:
idx: stage number in original model
"""
embed = []
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.weight",
f"stage{idx}.patch_embed.proj.weight",
)
)
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.projection.bias",
f"stage{idx}.patch_embed.proj.bias",
)
)
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.weight",
f"stage{idx}.patch_embed.norm.weight",
)
)
embed.append(
(
f"cvt.encoder.stages.{idx}.embedding.convolution_embeddings.normalization.bias",
f"stage{idx}.patch_embed.norm.bias",
)
)
return embed
def attention(idx, cnt):
"""
The function helps in renaming attention block layers weights.
Args:
idx: stage number in original model
cnt: count of blocks in each stage
"""
attention_weights = []
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.convolution.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.conv.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.bias",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_mean",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_mean",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.running_var",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.running_var",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_query.convolution_projection.normalization.num_batches_tracked",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_q.bn.num_batches_tracked",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.convolution.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.conv.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.bias",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_mean",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_mean",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.running_var",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.running_var",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_key.convolution_projection.normalization.num_batches_tracked",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_k.bn.num_batches_tracked",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.convolution.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.conv.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.weight",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.bias",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_mean",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_mean",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.running_var",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.running_var",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.convolution_projection_value.convolution_projection.normalization.num_batches_tracked",
f"stage{idx}.blocks.{cnt}.attn.conv_proj_v.bn.num_batches_tracked",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.weight",
f"stage{idx}.blocks.{cnt}.attn.proj_q.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_query.bias",
f"stage{idx}.blocks.{cnt}.attn.proj_q.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.weight",
f"stage{idx}.blocks.{cnt}.attn.proj_k.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_key.bias",
f"stage{idx}.blocks.{cnt}.attn.proj_k.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.weight",
f"stage{idx}.blocks.{cnt}.attn.proj_v.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.attention.projection_value.bias",
f"stage{idx}.blocks.{cnt}.attn.proj_v.bias",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.weight",
f"stage{idx}.blocks.{cnt}.attn.proj.weight",
)
)
attention_weights.append(
(
f"cvt.encoder.stages.{idx}.layers.{cnt}.attention.output.dense.bias",
f"stage{idx}.blocks.{cnt}.attn.proj.bias",
)
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc1.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.intermediate.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc1.bias")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.weight", f"stage{idx}.blocks.{cnt}.mlp.fc2.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.output.dense.bias", f"stage{idx}.blocks.{cnt}.mlp.fc2.bias")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.weight", f"stage{idx}.blocks.{cnt}.norm1.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_before.bias", f"stage{idx}.blocks.{cnt}.norm1.bias")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.weight", f"stage{idx}.blocks.{cnt}.norm2.weight")
)
attention_weights.append(
(f"cvt.encoder.stages.{idx}.layers.{cnt}.layernorm_after.bias", f"stage{idx}.blocks.{cnt}.norm2.bias")
)
return attention_weights
def cls_token(idx):
"""
Function helps in renaming cls_token weights
"""
token = []
token.append((f"cvt.encoder.stages.{idx}.cls_token", "stage2.cls_token"))
return token
def final():
"""
Function helps in renaming final classification layer
"""
head = []
head.append(("layernorm.weight", "norm.weight"))
head.append(("layernorm.bias", "norm.bias"))
head.append(("classifier.weight", "head.weight"))
head.append(("classifier.bias", "head.bias"))
return head
def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_folder):
"""
Fucntion to convert the microsoft cvt checkpoint to huggingface checkpoint
"""
img_labels_file = "imagenet-1k-id2label.json"
num_labels = 1000
repo_id = "huggingface/label-files"
num_labels = num_labels
id2label = json.load(open(cached_download(hf_hub_url(repo_id, img_labels_file, repo_type="dataset")), "r"))
id2label = {int(k): v for k, v in id2label.items()}
id2label = id2label
label2id = {v: k for k, v in id2label.items()}
config = config = CvtConfig(num_labels=num_labels, id2label=id2label, label2id=label2id)
# For depth size 13 (13 = 1+2+10)
if cvt_model.rsplit("/", 1)[-1][4:6] == "13":
config.depth = [1, 2, 10]
# For depth size 21 (21 = 1+4+16)
elif cvt_model.rsplit("/", 1)[-1][4:6] == "21":
config.depth = [1, 4, 16]
# For wide cvt (similar to wide-resnet) depth size 24 (w24 = 2 + 2 20)
else:
config.depth = [2, 2, 20]
config.num_heads = [3, 12, 16]
config.embed_dim = [192, 768, 1024]
model = CvtForImageClassification(config)
image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
image_processor.size["shortest_edge"] = image_size
original_weights = torch.load(cvt_file_name, map_location=torch.device("cpu"))
huggingface_weights = OrderedDict()
list_of_state_dict = []
for idx in range(len(config.depth)):
if config.cls_token[idx]:
list_of_state_dict = list_of_state_dict + cls_token(idx)
list_of_state_dict = list_of_state_dict + embeddings(idx)
for cnt in range(config.depth[idx]):
list_of_state_dict = list_of_state_dict + attention(idx, cnt)
list_of_state_dict = list_of_state_dict + final()
for gg in list_of_state_dict:
print(gg)
for i in range(len(list_of_state_dict)):
huggingface_weights[list_of_state_dict[i][0]] = original_weights[list_of_state_dict[i][1]]
model.load_state_dict(huggingface_weights)
model.save_pretrained(pytorch_dump_folder)
image_processor.save_pretrained(pytorch_dump_folder)
# Download the weights from zoo: https://1drv.ms/u/s!AhIXJn_J-blW9RzF3rMW7SsLHa8h?e=blQ0Al
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--cvt_model",
default="cvt-w24",
type=str,
help="Name of the cvt model you'd like to convert.",
)
parser.add_argument(
"--image_size",
default=384,
type=int,
help="Input Image Size",
)
parser.add_argument(
"--cvt_file_name",
default=r"cvtmodels\CvT-w24-384x384-IN-22k.pth",
type=str,
help="Input Image Size",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
args = parser.parse_args()
convert_cvt_checkpoint(args.cvt_model, args.image_size, args.cvt_file_name, args.pytorch_dump_folder_path)