|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Convert BiT checkpoints from the timm library.""" |
|
|
|
|
|
import argparse |
|
import json |
|
from pathlib import Path |
|
|
|
import requests |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
from timm import create_model |
|
from timm.data import resolve_data_config |
|
from timm.data.transforms_factory import create_transform |
|
|
|
from transformers import BitConfig, BitForImageClassification, BitImageProcessor |
|
from transformers.image_utils import PILImageResampling |
|
from transformers.utils import logging |
|
|
|
|
|
# Raise the transformers library's log verbosity to INFO for this conversion run.
logging.set_verbosity_info()

# Module-level logger. NOTE(review): nothing in the visible code of this file uses it;
# progress output goes through print() instead.
logger = logging.get_logger(__name__)
|
|
|
|
|
def get_config(model_name):
    """Build a ``BitConfig`` for the given timm BiT checkpoint.

    The ImageNet-1k id -> label mapping is downloaded from the
    ``huggingface/label-files`` dataset repository on the Hub.

    Args:
        model_name: timm model identifier, e.g. ``"resnetv2_50x1_bitm"``.
            Names containing ``"bit"`` get ``conv_layer="std_conv"``; all
            others get ``conv_layer=False``.

    Returns:
        A ``BitConfig`` with ``num_labels=1000`` and both ``id2label`` and
        ``label2id`` populated.
    """
    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    label_file = hf_hub_download(repo_id, filename, repo_type="dataset")
    # Use a context manager so the handle is closed deterministically, and
    # decode as UTF-8 explicitly so the result does not depend on the
    # platform's default text encoding.
    with open(label_file, "r", encoding="utf-8") as f:
        raw_id2label = json.load(f)
    # JSON object keys are always strings; the config wants integer class ids.
    id2label = {int(k): v for k, v in raw_id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    conv_layer = "std_conv" if "bit" in model_name else False

    # NOTE(review): conv_layer is forwarded as a keyword argument; confirm that
    # BitConfig actually consumes it (vs. storing it as an extra attribute).
    config = BitConfig(
        conv_layer=conv_layer,
        num_labels=1000,
        id2label=id2label,
        label2id=label2id,
    )

    return config
|
|
|
|
|
def rename_key(name):
    """Translate a timm BiT state-dict key into the Hugging Face ``Bit`` naming.

    Substring rewrites are applied in a fixed order. The result is then
    prefixed:

    * keys starting with ``norm`` go under ``bit.``;
    * keys that mention neither ``bit`` nor ``classifier`` go under
      ``bit.encoder.``;
    * everything else is returned as-is.
    """
    # str.replace is a no-op when the substring is absent, so no membership
    # guard is needed; the order of the pairs matters and is preserved.
    substitutions = (
        ("stem.conv", "bit.embedder.convolution"),
        ("blocks", "layers"),
        ("head.fc", "classifier.1"),
    )
    for old, new in substitutions:
        name = name.replace(old, new)

    if name.startswith("norm"):
        return "bit." + name
    if "bit" in name or "classifier" in name:
        return name
    return "bit.encoder." + name
|
|
|
|
|
|
|
def prepare_img():
    """Download the fixed COCO val2017 image used to compare timm and HF outputs.

    Returns:
        A lazily-decoded ``PIL.Image.Image`` read from the HTTP response stream.

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.Timeout: if connecting or reading stalls past the timeout.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Without a timeout, requests can block forever on a stalled connection.
    response = requests.get(url, stream=True, timeout=30)
    # Fail with a clear HTTP error instead of an opaque PIL decoding error
    # raised on an HTML error page.
    response.raise_for_status()
    im = Image.open(response.raw)
    return im
|
|
|
|
|
def _convert_state_dict(timm_model):
    """Return ``timm_model``'s state dict with keys renamed via ``rename_key``.

    The dict returned by ``state_dict()`` is rewritten in place; entry order is
    preserved. Tensors whose original key contains ``"head"`` are squeezed, and
    all others pass through unchanged.
    """
    state_dict = timm_model.state_dict()
    for key in list(state_dict):
        value = state_dict.pop(key)
        # NOTE(review): presumably the timm head stores its classifier weight
        # with trailing size-1 (1x1 conv) dims that the HF linear layer does not
        # have; squeeze() drops them. Confirm for new architectures.
        state_dict[rename_key(key)] = value.squeeze() if "head" in key else value
    return state_dict


def _build_image_processor(timm_transform):
    """Build a ``BitImageProcessor`` mirroring timm's evaluation transform.

    The resize size and interpolation come from the first transform, the crop
    size from the second, and the normalization mean/std from the last.

    Raises:
        ValueError: if timm uses an interpolation mode with no PIL mapping here.
    """
    timm_transforms = timm_transform.transforms
    resize, center_crop, normalize = timm_transforms[0], timm_transforms[1], timm_transforms[-1]

    pillow_resamplings = {
        "bilinear": PILImageResampling.BILINEAR,
        "bicubic": PILImageResampling.BICUBIC,
        "nearest": PILImageResampling.NEAREST,
    }
    interpolation = resize.interpolation.value
    if interpolation not in pillow_resamplings:
        raise ValueError(
            f"Unsupported timm interpolation mode {interpolation!r}; "
            f"expected one of {sorted(pillow_resamplings)}"
        )

    return BitImageProcessor(
        do_resize=True,
        size={"shortest_edge": resize.size},
        resample=pillow_resamplings[interpolation],
        do_center_crop=True,
        crop_size={"height": center_crop.size[0], "width": center_crop.size[1]},
        do_normalize=True,
        image_mean=normalize.mean.tolist(),
        image_std=normalize.std.tolist(),
    )


@torch.no_grad()
def convert_bit_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our BiT structure.

    Steps:

    1. Load the pretrained timm model and copy its renamed weights into a
       ``BitForImageClassification``.
    2. Build a ``BitImageProcessor`` that mirrors timm's preprocessing.
    3. Verify on a reference image that the pixel values match and that the
       logits match (``atol=1e-3``).
    4. Optionally save the model and processor, and/or push them to the Hub.

    Args:
        model_name: timm model identifier to convert.
        pytorch_dump_folder_path: output directory, or ``None`` to skip saving.
        push_to_hub: if true, push the model and processor to
            ``ybelkada/<model_name>``.

    Raises:
        RuntimeError: if the converted preprocessing or model outputs do not
            match timm's.
        ValueError: if timm's interpolation mode has no PIL equivalent.
    """
    config = get_config(model_name)

    # Load the original model from timm.
    timm_model = create_model(model_name, pretrained=True)
    timm_model.eval()

    # Load the renamed weights into the HF model.
    state_dict = _convert_state_dict(timm_model)
    model = BitForImageClassification(config)
    model.eval()
    model.load_state_dict(state_dict)

    # Build an HF image processor equivalent to timm's eval transform.
    transform = create_transform(**resolve_data_config({}, model=timm_model))
    processor = _build_image_processor(transform)

    image = prepare_img()
    timm_pixel_values = transform(image).unsqueeze(0)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, which would let a broken conversion be saved or pushed.
    if not torch.allclose(timm_pixel_values, pixel_values):
        raise RuntimeError(
            "Pixel values from BitImageProcessor do not match timm's transform "
            f"(max abs diff {(timm_pixel_values - pixel_values).abs().max().item():.3e})"
        )

    # Gradients are already disabled by the @torch.no_grad() decorator.
    outputs = model(pixel_values)
    logits = outputs.logits

    print("Logits:", logits[0, :3])
    print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])
    timm_logits = timm_model(pixel_values)
    if timm_logits.shape != logits.shape:
        raise RuntimeError(f"Logit shape mismatch: timm {tuple(timm_logits.shape)} vs HF {tuple(logits.shape)}")
    if not torch.allclose(timm_logits, logits, atol=1e-3):
        raise RuntimeError(
            "Converted model logits differ from timm's "
            f"(max abs diff {(timm_logits - logits).abs().max().item():.3e}, atol=1e-3)"
        )
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        # parents=True so nested output paths (e.g. "out/bit/50x1") work.
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        print(f"Saving model {model_name} and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model {model_name} and processor to the hub")
        model.push_to_hub(f"ybelkada/{model_name}")
        processor.push_to_hub(f"ybelkada/{model_name}")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run a single conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="resnetv2_50x1_bitm",
        help="Name of the BiT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        default=None,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub.")

    cli_args = parser.parse_args()
    convert_bit_checkpoint(
        model_name=cli_args.model_name,
        pytorch_dump_folder_path=cli_args.pytorch_dump_folder_path,
        push_to_hub=cli_args.push_to_hub,
    )
|
|