# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/facebookresearch/detr/blob/main/d2/converter.py
# ------------------------------------------------------------------------------------------------
"""Convert a DETR-family checkpoint into the detrex state-dict layout.

Reads a checkpoint (local path or https URL), renames every parameter key
from the original DETR naming scheme to the detrex naming scheme, optionally
remaps a 91-way classification head onto the 80 used COCO categories, and
saves the result.
"""

import argparse

import numpy as np
import torch

# fmt: off
# The 80 COCO category ids that are actually annotated (of the 91 reserved
# ids); used to slice a 91-way classification head down to 80 classes.
COCO_IDX = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
            21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39,
            40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
            57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76,
            77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
# fmt: on

# Ordered (old, new) substring renames for decoder keys. Order matters and
# only the FIRST matching pair is applied, mirroring the original elif chain.
_DECODER_RENAMES = (
    ("ca_kcontent_proj", "attentions.1.key_content_proj"),
    ("ca_kpos_proj", "attentions.1.key_pos_proj"),
    ("ca_qcontent_proj", "attentions.1.query_content_proj"),
    ("ca_qpos_proj", "attentions.1.query_pos_proj"),
    ("ca_qpos_sine_proj", "attentions.1.query_pos_sine_proj"),
    ("ca_v_proj", "attentions.1.value_proj"),
    ("sa_kcontent_proj", "attentions.0.key_content_proj"),
    ("sa_kpos_proj", "attentions.0.key_pos_proj"),
    ("sa_qcontent_proj", "attentions.0.query_content_proj"),
    ("sa_qpos_proj", "attentions.0.query_pos_proj"),
    ("sa_v_proj", "attentions.0.value_proj"),
    ("self_attn.out_proj", "attentions.0.out_proj"),
    ("cross_attn.out_proj", "attentions.1.out_proj"),
    ("linear1", "ffns.0.layers.0.0"),
    ("linear2", "ffns.0.layers.1"),
    ("norm1", "norms.0"),
    ("norm2", "norms.1"),
    ("norm3", "norms.2"),
    ("activation", "ffns.0.layers.0.1"),
)

# Ordered (old, new) substring renames for encoder keys; first match wins.
_ENCODER_RENAMES = (
    ("linear1", "ffns.0.layers.0.0"),
    ("linear2", "ffns.0.layers.1"),
    ("norm1", "norms.0"),
    ("norm2", "norms.1"),
    ("activation", "ffns.0.layers.0.1"),
)


def parse_args():
    """Parse the source/output checkpoint paths from the command line."""
    parser = argparse.ArgumentParser("detrex model converter")
    parser.add_argument(
        "--source_model", default="", type=str, help="Path or url to the DETR model to convert"
    )
    parser.add_argument(
        "--output_model", default="", type=str, help="Path where to save the converted model"
    )
    return parser.parse_args()


def _apply_first_rename(key, renames):
    """Apply the first matching (old, new) substring rename; return key unchanged if none match."""
    for old, new in renames:
        if old in key:
            return key.replace(old, new)
    return key


def _convert_backbone_key(key):
    """Map a torchvision-style ResNet key to the detectron2-style backbone layout."""
    key = key.replace("backbone.0.body.", "")
    # Keys outside the residual stages (conv1/bn1) belong to the stem.
    if "layer" not in key:
        key = "stem." + key
    # layer1..layer4 -> res2..res5 (detectron2 counts the stem as stage 1).
    for t in (1, 2, 3, 4):
        key = key.replace(f"layer{t}", f"res{t + 1}")
    # bn1..bn3 -> conv1.norm..conv3.norm (norms live under their conv).
    for t in (1, 2, 3):
        key = key.replace(f"bn{t}", f"conv{t}.norm")
    key = key.replace("downsample.0", "shortcut")
    key = key.replace("downsample.1", "shortcut.norm")
    return "backbone." + key


def convert_key(key):
    """Return the detrex-style name for one original DETR state-dict key.

    Pure string transformation; keys that match no known pattern are
    returned unchanged.
    """
    if "backbone" in key:
        key = _convert_backbone_key(key)
    if "decoder" in key:
        # The decoder's final LayerNorm is renamed before the generic table.
        if "decoder.norm" in key:
            key = key.replace("decoder.norm", "decoder.post_norm_layer")
        key = _apply_first_rename(key, _DECODER_RENAMES)
    if "encoder" in key:
        # Encoder self-attention keeps its nested out_proj under the new name.
        if "self_attn" in key:
            key = key.replace("self_attn", "attentions.0.attn")
        key = _apply_first_rename(key, _ENCODER_RENAMES)
    return key


def main():
    """Load the source checkpoint, rename all keys, and save the converted model."""
    args = parse_args()
    coco_idx = np.array(COCO_IDX)

    # Generalized from startswith("https"): plain-http URLs previously fell
    # through to torch.load and could only fail.
    if args.source_model.startswith(("http://", "https://")):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.source_model, map_location="cpu", check_hash=True
        )
    else:
        checkpoint = torch.load(args.source_model, map_location="cpu")

    model_to_convert = checkpoint["model"]
    model_converted = {}
    for old_k, weight in model_to_convert.items():
        k = convert_key(old_k)
        print(old_k, "->", k)
        v = weight.detach()
        # A 91-way classification head is sliced down to the 80 real COCO
        # categories; heads already sized differently are copied as-is.
        if "class_embed" in old_k and v.shape[0] == 91:
            shape_old = v.shape
            model_converted[k] = v[coco_idx]
            print(
                "Head conversion: changing shape from {} to {}".format(
                    shape_old, model_converted[k].shape
                )
            )
            continue
        model_converted[k] = v

    torch.save({"model": model_converted}, args.output_model)


if __name__ == "__main__":
    main()