|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Convert ALIGN checkpoints from the original repository.""" |
|
|
|
import argparse |
|
import os |
|
|
|
import align |
|
import numpy as np |
|
import requests |
|
import tensorflow as tf |
|
import torch |
|
from PIL import Image |
|
from tokenizer import Tokenizer |
|
|
|
from transformers import ( |
|
AlignConfig, |
|
AlignModel, |
|
AlignProcessor, |
|
BertConfig, |
|
BertTokenizer, |
|
EfficientNetConfig, |
|
EfficientNetImageProcessor, |
|
) |
|
from transformers.utils import logging |
|
|
|
|
|
# Switch the transformers logging utilities to INFO verbosity for this script.
logging.set_verbosity_info()

# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
def preprocess(image):
    """Resize `image` to 346x346 and cut a 289x289 window out of its middle.

    Args:
        image: Image tensor accepted by `tf.image.resize`.

    Returns:
        The result of `tf.image.crop_to_bounding_box` applied to the resized image.
    """
    resized_side, cropped_side = 346, 289
    # Same offset on both axes so the crop window is centred.
    offset = (resized_side - cropped_side) // 2

    resized = tf.image.resize(image, (resized_side, resized_side))
    return tf.image.crop_to_bounding_box(resized, offset, offset, cropped_side, cropped_side)
|
|
|
|
|
def get_align_config():
    """Create the `AlignConfig` used to instantiate the HF model.

    The vision config is loaded from "google/efficientnet-b7" and then has a handful of
    fields overridden; the text config is a default `BertConfig`. `projection_dim` is 640.

    Returns:
        The `AlignConfig` built from both sub-configs.
    """
    # Applied in this exact order on top of the pretrained EfficientNet config.
    vision_overrides = (
        ("image_size", 289),
        ("hidden_dim", 640),
        ("id2label", {"0": "LABEL_0", "1": "LABEL_1"}),
        ("label2id", {"LABEL_0": 0, "LABEL_1": 1}),
        ("depthwise_padding", []),
    )

    efficientnet_config = EfficientNetConfig.from_pretrained("google/efficientnet-b7")
    for field_name, field_value in vision_overrides:
        setattr(efficientnet_config, field_name, field_value)

    bert_config = BertConfig()
    return AlignConfig.from_text_vision_configs(
        text_config=bert_config,
        vision_config=efficientnet_config,
        projection_dim=640,
    )
|
|
|
|
|
|
|
def prepare_img():
    """Download the image used to compare the original and converted models.

    Returns:
        The `PIL.Image.Image` opened from the streamed HTTP response body.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within the timeout.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Without a timeout a stalled connection would block the conversion forever.
    response = requests.get(url, stream=True, timeout=30)
    # Surface HTTP failures directly instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(response.raw)
|
|
|
|
|
def get_processor():
    """Assemble the `AlignProcessor` that is saved alongside the converted model.

    It combines an `EfficientNetImageProcessor` (center crop, rescale by 1/127.5 with
    offset, no normalization, bilinear resampling) with the "bert-base-uncased"
    tokenizer whose `model_max_length` is set to 64.

    Returns:
        The configured `AlignProcessor`.
    """
    image_settings = {
        "do_center_crop": True,
        "rescale_factor": 1 / 127.5,
        "rescale_offset": True,
        "do_normalize": False,
        "include_top": False,
        "resample": Image.BILINEAR,
    }
    efficientnet_processor = EfficientNetImageProcessor(**image_settings)

    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    bert_tokenizer.model_max_length = 64

    return AlignProcessor(image_processor=efficientnet_processor, tokenizer=bert_tokenizer)
|
|
|
|
|
|
|
def rename_keys(original_param_names):
    """
    Build the mapping from original TF variable names to HF ALIGN parameter names.

    Args:
        original_param_names: Collection of variable names (strings) found in the TF model.

    Returns:
        `dict[str, str]` mapping every recognised TF name that is present in
        `original_param_names` to its HF name. Vision entries are prefixed with
        "vision_model."; names without a known counterpart are left out.
    """
    # A set makes each "is this variable present?" check O(1) instead of scanning a list.
    available_names = set(original_param_names)

    # Block ids are whatever sits between "block" and the first "_" (e.g. "1a").
    # They are sorted and renumbered consecutively to get the HF block indices.
    block_names = sorted(
        {name.split("_")[0].split("block")[1] for name in original_param_names if name.startswith("block")}
    )
    block_name_mapping = {block: str(index) for index, block in enumerate(block_names)}

    # (TF name, HF name) pairs for the vision stem.
    stem_keys = (
        ("stem_conv/kernel:0", "embeddings.convolution.weight"),
        ("stem_bn/gamma:0", "embeddings.batchnorm.weight"),
        ("stem_bn/beta:0", "embeddings.batchnorm.bias"),
        ("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"),
        ("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"),
    )
    # Per-block suffixes: TF "block{id}_<suffix>" -> HF "encoder.blocks.{idx}.<suffix>".
    block_keys = (
        ("expand_conv/kernel:0", "expansion.expand_conv.weight"),
        ("expand_bn/gamma:0", "expansion.expand_bn.weight"),
        ("expand_bn/beta:0", "expansion.expand_bn.bias"),
        ("expand_bn/moving_mean:0", "expansion.expand_bn.running_mean"),
        ("expand_bn/moving_variance:0", "expansion.expand_bn.running_var"),
        ("dwconv/depthwise_kernel:0", "depthwise_conv.depthwise_conv.weight"),
        ("bn/gamma:0", "depthwise_conv.depthwise_norm.weight"),
        ("bn/beta:0", "depthwise_conv.depthwise_norm.bias"),
        ("bn/moving_mean:0", "depthwise_conv.depthwise_norm.running_mean"),
        ("bn/moving_variance:0", "depthwise_conv.depthwise_norm.running_var"),
        ("se_reduce/kernel:0", "squeeze_excite.reduce.weight"),
        ("se_reduce/bias:0", "squeeze_excite.reduce.bias"),
        ("se_expand/kernel:0", "squeeze_excite.expand.weight"),
        ("se_expand/bias:0", "squeeze_excite.expand.bias"),
        ("project_conv/kernel:0", "projection.project_conv.weight"),
        ("project_bn/gamma:0", "projection.project_bn.weight"),
        ("project_bn/beta:0", "projection.project_bn.bias"),
        ("project_bn/moving_mean:0", "projection.project_bn.running_mean"),
        ("project_bn/moving_variance:0", "projection.project_bn.running_var"),
    )
    # Per-layer suffixes of the text encoder.
    text_layer_keys = (
        ("attention/self/query/kernel:0", "attention.self.query.weight"),
        ("attention/self/query/bias:0", "attention.self.query.bias"),
        ("attention/self/key/kernel:0", "attention.self.key.weight"),
        ("attention/self/key/bias:0", "attention.self.key.bias"),
        ("attention/self/value/kernel:0", "attention.self.value.weight"),
        ("attention/self/value/bias:0", "attention.self.value.bias"),
        ("attention/output/dense/kernel:0", "attention.output.dense.weight"),
        ("attention/output/dense/bias:0", "attention.output.dense.bias"),
        ("attention/output/LayerNorm/gamma:0", "attention.output.LayerNorm.weight"),
        ("attention/output/LayerNorm/beta:0", "attention.output.LayerNorm.bias"),
        ("intermediate/dense/kernel:0", "intermediate.dense.weight"),
        ("intermediate/dense/bias:0", "intermediate.dense.bias"),
        ("output/dense/kernel:0", "output.dense.weight"),
        ("output/dense/bias:0", "output.dense.bias"),
        ("output/LayerNorm/gamma:0", "output.LayerNorm.weight"),
        ("output/LayerNorm/beta:0", "output.LayerNorm.bias"),
    )
    # Text-model suffixes outside the encoder layers (embeddings and pooler).
    text_other_keys = (
        ("embeddings/word_embeddings/weight:0", "embeddings.word_embeddings.weight"),
        ("embeddings/position_embeddings/embeddings:0", "embeddings.position_embeddings.weight"),
        ("embeddings/token_type_embeddings/embeddings:0", "embeddings.token_type_embeddings.weight"),
        ("embeddings/LayerNorm/gamma:0", "embeddings.LayerNorm.weight"),
        ("embeddings/LayerNorm/beta:0", "embeddings.LayerNorm.bias"),
        ("pooler/dense/kernel:0", "pooler.dense.weight"),
        ("pooler/dense/bias:0", "pooler.dense.bias"),
    )
    # Variables that live outside both towers and keep an unprefixed HF name.
    head_keys = (
        ("dense/kernel:0", "text_projection.weight"),
        ("dense/bias:0", "text_projection.bias"),
        ("temperature:0", "temperature"),
    )

    old = "tf_bert_model/bert"
    new = "text_model"
    num_text_layers = 12

    # Candidates are collected vision first, then text, then the shared head.
    candidates = [(tf_name, f"vision_model.{hf_name}") for tf_name, hf_name in stem_keys]
    for block in block_names:
        hf_block = block_name_mapping[block]
        candidates.extend(
            (f"block{block}_{tf_suffix}", f"vision_model.encoder.blocks.{hf_block}.{hf_suffix}")
            for tf_suffix, hf_suffix in block_keys
        )
    for layer in range(num_text_layers):
        candidates.extend(
            (f"{old}/encoder/layer_._{layer}/{tf_suffix}", f"{new}.encoder.layer.{layer}.{hf_suffix}")
            for tf_suffix, hf_suffix in text_layer_keys
        )
    candidates.extend((f"{old}/{tf_suffix}", f"{new}.{hf_suffix}") for tf_suffix, hf_suffix in text_other_keys)
    candidates.extend(head_keys)

    # Only keep entries for variables that actually exist in the original model.
    return {tf_name: hf_name for tf_name, hf_name in candidates if tf_name in available_names}
|
|
|
|
|
def replace_params(hf_params, tf_params, key_mapping):
    """
    Copy the original TF weights into the HF tensors, in place.

    Args:
        hf_params: Mapping from HF parameter name to `torch.Tensor`; the tensors are overwritten via `copy_`.
        tf_params: Mapping from TF variable name to its NumPy value.
        key_mapping: Mapping from TF variable name to HF parameter name. TF variables
            missing from this mapping are skipped.

    Raises:
        KeyError: If a mapped HF name is not present in `hf_params`.
    """
    for key, value in tf_params.items():
        if key not in key_mapping:
            continue

        hf_key = key_mapping[key]
        # The checks are order-sensitive: more specific name patterns come first.
        if "_conv" in key and "kernel" in key:
            # Regular conv kernels: reorder axes (0, 1, 2, 3) -> (3, 2, 0, 1).
            new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
        elif "embeddings" in key:
            new_hf_value = torch.from_numpy(value)
        elif "depthwise_kernel" in key:
            # Depthwise conv kernels: reorder axes (0, 1, 2, 3) -> (2, 3, 0, 1).
            new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
        elif "kernel" in key:
            # Remaining kernels (dense layers) are stored transposed.
            new_hf_value = torch.from_numpy(np.transpose(value))
        elif "temperature" in key:
            # `copy_` needs a tensor; `as_tensor` accepts NumPy scalars and arrays alike.
            new_hf_value = torch.as_tensor(value)
        elif "bn/gamma" in key or "bn/beta" in key:
            # Each substring must be tested on its own: `"bn/gamma" or ...` is always truthy.
            new_hf_value = torch.from_numpy(np.transpose(value)).squeeze()
        else:
            new_hf_value = torch.from_numpy(value)

        # Replace the HF parameter with the converted original weight.
        hf_params[hf_key].copy_(new_hf_value)
|
|
|
|
|
@torch.no_grad()
def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_model, push_to_hub):
    """
    Copy/paste/tweak model's weights to our ALIGN structure.

    Args:
        checkpoint_path: Path of the original TF weights passed to `load_weights`.
        pytorch_dump_folder_path: Directory the converted model and processor are saved to.
        save_model: When truthy, save the converted model and processor locally.
        push_to_hub: When truthy, push the processor and model to the hub as "align-base".

    Raises:
        ValueError: If the image or text features of the converted model differ from the
            original model's features by more than `atol=1e-3`.
    """
    seq_length = 64
    prompt = "A picture of a cat"

    # Load the original model
    tok = Tokenizer(seq_length)
    original_model = align.Align("efficientnet-b7", "bert-base", 640, seq_length, tok.get_vocab_size())
    original_model.compile()
    original_model.load_weights(checkpoint_path)

    # Collect trainable variables first, then non-trainable ones, keyed by variable name.
    tf_params = {param.name: param.numpy() for param in original_model.trainable_variables}
    for param in original_model.non_trainable_variables:
        tf_params[param.name] = param.numpy()
    tf_param_names = list(tf_params.keys())

    # Instantiate the HF model; the state dict tensors are overwritten in place below.
    config = get_align_config()
    hf_model = AlignModel(config).eval()
    hf_params = hf_model.state_dict()

    # Map TF names to HF names and copy the weights over.
    print("Converting parameters...")
    key_mapping = rename_keys(tf_param_names)
    replace_params(hf_params, tf_params, key_mapping)

    # Download the test image once and feed the same image to both pipelines.
    pil_image = prepare_img()

    # HF model inference. Gradients are already disabled by the decorator and the
    # model has been in eval mode since construction.
    processor = get_processor()
    inputs = processor(
        images=pil_image, text=prompt, padding="max_length", max_length=seq_length, return_tensors="pt"
    )
    outputs = hf_model(**inputs)
    hf_image_features = outputs.image_embeds.detach().numpy()
    hf_text_features = outputs.text_embeds.detach().numpy()

    # Original model inference: no rescaling here and channels-last pixel values.
    original_model.trainable = False
    tf_image_processor = EfficientNetImageProcessor(
        do_center_crop=True,
        do_rescale=False,
        do_normalize=False,
        include_top=False,
        resample=Image.BILINEAR,
    )
    image = tf_image_processor(images=pil_image, return_tensors="tf", data_format="channels_last")["pixel_values"]
    text = tok(tf.constant([prompt]))

    image_features = original_model.image_encoder(image, training=False)
    text_features = original_model.text_encoder(text, training=False)

    image_features = tf.nn.l2_normalize(image_features, axis=-1)
    text_features = tf.nn.l2_normalize(text_features, axis=-1)

    # Check whether original and HF model outputs match -> np.allclose
    if not np.allclose(image_features, hf_image_features, atol=1e-3):
        raise ValueError("The predicted image features are not the same.")
    if not np.allclose(text_features, hf_text_features, atol=1e-3):
        raise ValueError("The predicted text features are not the same.")
    print("Model outputs match!")

    if save_model:
        # `makedirs` also creates missing parent folders and tolerates an existing directory.
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        # Save converted model and image processor
        hf_model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Push model and image processor to hub
        print("Pushing converted ALIGN to the hub...")
        processor.push_to_hub("align-base")
        hf_model.push_to_hub("align-base")
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Input checkpoint and output folder.
    cli.add_argument(
        "--checkpoint_path",
        type=str,
        default="./weights/model-weights",
        help="Path to the pretrained TF ALIGN checkpoint.",
    )
    cli.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        default="hf_model",
        help="Path to the output PyTorch model directory.",
    )
    # Opt-in switches for what to do with the converted model.
    cli.add_argument("--save_model", help="Save model to local", action="store_true")
    cli.add_argument("--push_to_hub", help="Push model and image processor to the hub", action="store_true")

    options = cli.parse_args()
    convert_align_checkpoint(
        options.checkpoint_path,
        options.pytorch_dump_folder_path,
        options.save_model,
        options.push_to_hub,
    )
|
|