File size: 1,359 Bytes
9bc4638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse

import torch


def main(args):
    """Rename the parameter keys of a saved checkpoint and save a converted copy.

    Loads ``args.src`` with ``torch.load`` (mapped to CPU), takes the mapping
    stored under its ``"model"`` key, drops every entry whose key contains
    ``"teacher"``, applies a fixed, ordered list of substring renames to the
    remaining keys, and saves ``{"model": renamed}`` alongside the source file
    with ``_converted`` inserted before the file extension
    (``ckpt.pt`` -> ``ckpt_converted.pt``).

    Args:
        args: Object exposing a ``src`` attribute with the checkpoint path
            (e.g. the ``argparse.Namespace`` built by this script).

    Raises:
        KeyError: If the loaded checkpoint is a dict without a ``"model"`` key.
    """
    # Function-scope import so the block does not depend on file-level imports.
    import os

    # (old, new) substring pairs, applied to every key in this order.
    renames = (
        ("backbone.vision_backbone", "image_encoder"),
        ("mlp.fc1", "mlp.layers.0"),
        ("mlp.fc2", "mlp.layers.1"),
        ("convs", "neck.convs"),
        ("transformer.encoder", "memory_attention"),
        ("maskmem_backbone", "memory_encoder"),
        ("mlp.lin1", "mlp.layers.0"),
        ("mlp.lin2", "mlp.layers.1"),
    )

    # NOTE(review): torch.load unpickles the file; only use on trusted checkpoints.
    sd = torch.load(args.src, map_location="cpu")["model"]
    sd = {k: v for k, v in sd.items() if "teacher" not in k}
    # One pass per rule keeps the exact key-collision behavior of chained renames.
    for old, new in renames:
        sd = {k.replace(old, new): v for k, v in sd.items()}

    # Build the destination from the extension only. A plain
    # str.replace(".pt", ...) on the whole path would also rewrite directory
    # names and, when the path has no ".pt", would overwrite the source file.
    root, ext = os.path.splitext(args.src)
    torch.save({"model": sd}, root + "_converted" + ext)


if __name__ == "__main__":
    # Script entry point: read the required --src checkpoint path from the
    # command line and hand the parsed namespace to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--src", type=str, required=True)
    main(cli.parse_args())