Spaces:
Running
Running
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvNext checkpoints from the original repository.
URL: https://github.com/facebookresearch/ConvNeXt"""
import argparse | |
import json | |
from pathlib import Path | |
import requests | |
import torch | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
from transformers import ConvNextConfig, ConvNextForImageClassification, ConvNextImageProcessor | |
from transformers.utils import logging | |
# Raise the transformers logging verbosity to INFO for the duration of this script.
logging.set_verbosity_info()
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
def get_convnext_config(checkpoint_url):
    """
    Build the `ConvNextConfig` that matches an original ConvNeXt checkpoint.

    The model size (tiny/small/base/large/xlarge) and the label set (ImageNet-1k vs. ImageNet-22k) are inferred from
    substrings of `checkpoint_url`. The id2label mapping is downloaded from the `huggingface/label-files` dataset
    repository on the Hub.

    Args:
        checkpoint_url (`str`):
            URL of the original checkpoint. Must contain one of the size names, and contains `"1k"` when the
            classification head has 1000 classes.

    Returns:
        `tuple`: `(config, expected_shape)`, where `expected_shape` is `(1, num_labels)`, i.e. the shape of the
        logits expected for a single image.

    Raises:
        ValueError: If none of the known size names appears in `checkpoint_url`.
    """
    # "xlarge" contains "large", so the more specific name has to be tested first.
    architectures = (
        ("xlarge", [3, 3, 27, 3], [256, 512, 1024, 2048]),
        ("large", [3, 3, 27, 3], [192, 384, 768, 1536]),
        ("base", [3, 3, 27, 3], [128, 256, 512, 1024]),
        ("small", [3, 3, 27, 3], [96, 192, 384, 768]),
        ("tiny", [3, 3, 9, 3], [96, 192, 384, 768]),
    )
    for size_name, depths, hidden_sizes in architectures:
        if size_name in checkpoint_url:
            break
    else:
        # fail early with a clear message instead of hitting an unbound local after the label download
        raise ValueError(f"Could not infer the ConvNext model size from URL: {checkpoint_url}")

    is_1k = "1k" in checkpoint_url
    if is_1k:
        num_labels = 1000
        filename = "imagenet-1k-id2label.json"
    else:
        num_labels = 21841
        filename = "imagenet-22k-id2label.json"
    expected_shape = (1, num_labels)

    repo_id = "huggingface/label-files"
    label_file = hf_hub_download(repo_id, filename, repo_type="dataset")
    # use a context manager so the label file handle is closed deterministically
    with open(label_file, "r", encoding="utf-8") as f:
        id2label = {int(k): v for k, v in json.load(f).items()}
    if not is_1k:
        # this dataset contains 21843 labels but the model only has 21841
        # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
        del id2label[9205]
        del id2label[15027]

    config = ConvNextConfig()
    config.num_labels = num_labels
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    config.hidden_sizes = hidden_sizes
    config.depths = depths

    return config, expected_shape
def rename_key(name):
    """
    Translate a parameter name of the original ConvNeXt checkpoint into the HuggingFace naming scheme.

    Args:
        name (`str`): Key of the original state dict, e.g. `"stages.0.0.gamma"`.

    Returns:
        `str`: The renamed key, e.g. `"encoder.stages.0.layers.0.layer_scale_parameter"`. The `"convnext."` prefix is
        not added here.
    """
    # The stem (downsample_layers.0) becomes the embeddings; the other downsampling layers move into their stage.
    stem_and_downsampling = [
        ("downsample_layers.0.0", "embeddings.patch_embeddings"),
        ("downsample_layers.0.1", "embeddings.norm"),  # we rename to layernorm later on
    ]
    stem_and_downsampling += [
        (f"downsample_layers.{stage}.{index}", f"stages.{stage}.downsampling_layer.{index}")
        for stage in (1, 2, 3)
        for index in (0, 1)
    ]
    for old, new in stem_and_downsampling:
        name = name.replace(old, new)

    if "stages" in name and "downsampling_layer" not in name:
        # stages.0.0. for instance should be renamed to stages.0.layers.0.
        split_at = len("stages.0")
        name = f"{name[:split_at]}.layers{name[split_at:]}"

    # Generic renames, applied in this exact order.
    generic_renames = (
        ("stages", "encoder.stages"),
        ("norm", "layernorm"),
        ("gamma", "layer_scale_parameter"),
        ("head", "classifier"),
    )
    for old, new in generic_renames:
        name = name.replace(old, new)

    return name
# We will verify our results on an image of cute cats
def prepare_img():
    """
    Download the COCO val2017 image used to sanity-check the converted model.

    Returns:
        `PIL.Image.Image`: The image, opened from the raw HTTP response stream.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within the timeout.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Bound the wait and fail loudly on HTTP errors rather than handing an error page to PIL.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    return Image.open(response.raw)
# First three logits expected for the first image of the batch, keyed by the official checkpoint URL.
# note: the logits below were obtained without center cropping
_EXPECTED_LOGITS = {
    "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth": [-0.1210, -0.6605, 0.1918],
    "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth": [-0.4473, -0.1847, -0.6365],
    "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth": [0.4525, 0.7539, 0.0308],
    "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_384.pth": [0.3561, 0.6350, -0.0384],
    "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth": [0.4174, -0.0989, 0.1489],
    "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_384.pth": [0.2513, -0.1349, -0.1613],
    "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth": [1.2980, 0.3631, -0.1198],
    "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth": [1.2963, 0.1227, 0.1723],
    "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth": [1.7956, 0.8390, 0.2820],
    "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth": [-0.2822, -0.0502, -0.0878],
    "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth": [-0.5672, -0.0730, -0.4348],
    "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth": [0.2681, 0.2365, 0.6246],
    "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth": [-0.2642, 0.3931, 0.5116],
    "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth": [-0.6677, -0.1873, -0.8379],
    "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth": [-0.7749, -0.2967, -0.6444],
}


def _hub_model_name(checkpoint_url):
    """Derive the Hub repository name (e.g. `convnext-tiny-224`) from substrings of the checkpoint URL."""
    model_name = "convnext"
    # "xlarge" is tested before "large" because the latter is a substring of the former
    for size in ("tiny", "small", "base", "xlarge", "large"):
        if size in checkpoint_url:
            model_name += f"-{size}"
            break
    for resolution in ("224", "384"):
        if resolution in checkpoint_url:
            model_name += f"-{resolution}"
            break
    if "22k" in checkpoint_url:
        # checkpoints pre-trained on 22k and fine-tuned on 1k carry both markers in their URL
        model_name += "-22k-1k" if "1k" in checkpoint_url else "-22k"
    return model_name


def convert_convnext_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ConvNext structure.

    Downloads the original checkpoint, renames its keys, loads them into a `ConvNextForImageClassification`, checks
    the logits on a test image against reference values, saves the model and image processor to
    `pytorch_dump_folder_path` and finally pushes the model to the Hub under the `nielsr` organization.

    Args:
        checkpoint_url (`str`): URL of one of the official ConvNeXt checkpoints.
        pytorch_dump_folder_path (`str`): Output directory; created (including parents) when missing.

    Raises:
        ValueError: If no reference logits are known for `checkpoint_url`.
        AssertionError: If the converted model's logits do not match the reference values or expected shape.
    """
    # fail before downloading anything if we have no reference logits for this checkpoint
    if checkpoint_url not in _EXPECTED_LOGITS:
        raise ValueError(f"Unknown URL: {checkpoint_url}")
    expected_logits = torch.tensor(_EXPECTED_LOGITS[checkpoint_url])

    # define ConvNext configuration based on URL
    config, expected_shape = get_convnext_config(checkpoint_url)

    # load original state_dict from URL (on CPU, so no GPU is needed whatever device the weights were saved from)
    original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]

    # rename keys and add prefix to all keys except classifier head
    state_dict = {}
    for original_key, value in original_state_dict.items():
        key = rename_key(original_key)
        if not key.startswith("classifier"):
            key = "convnext." + key
        state_dict[key] = value

    # load HuggingFace model
    model = ConvNextForImageClassification(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Check outputs on an image, prepared by ConvNextImageProcessor
    size = 224 if "224" in checkpoint_url else 384
    image_processor = ConvNextImageProcessor(size=size)
    pixel_values = image_processor(images=prepare_img(), return_tensors="pt").pixel_values

    # inference only: no need to track gradients
    with torch.no_grad():
        logits = model(pixel_values).logits

    assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "logits do not match the reference values"
    assert logits.shape == expected_shape, "logits do not have the expected shape"

    Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)

    print("Pushing model to the hub...")
    model_name = _hub_model_name(checkpoint_url)

    # NOTE(review): `repo_path_or_name` / `organization` look like the legacy `push_to_hub` arguments - confirm
    # they are still accepted by the installed transformers version.
    model.push_to_hub(
        repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
        organization="nielsr",
        commit_message="Add model",
    )
if __name__ == "__main__":
    # Command-line entry point: convert a single checkpoint.
    cli = argparse.ArgumentParser()
    # Optional: falls back to the tiny ImageNet-1k checkpoint.
    cli.add_argument(
        "--checkpoint_url",
        type=str,
        default="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
        help="URL of the original ConvNeXT checkpoint you'd like to convert.",
    )
    # Required: where the converted model and image processor are written.
    cli.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        required=True,
        help="Path to the output PyTorch model directory.",
    )

    options = cli.parse_args()
    convert_convnext_checkpoint(
        checkpoint_url=options.checkpoint_url,
        pytorch_dump_folder_path=options.pytorch_dump_folder_path,
    )