Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,359 Bytes
72780d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
# Ordered (old, new) substring renames applied to every checkpoint key.
# Each rule is applied to the whole state dict before the next one runs,
# in the same order the individual rewrite steps were originally written.
_KEY_RENAMES = (
    ("backbone.vision_backbone", "image_encoder"),
    ("mlp.fc1", "mlp.layers.0"),
    ("mlp.fc2", "mlp.layers.1"),
    ("convs", "neck.convs"),
    ("transformer.encoder", "memory_attention"),
    ("maskmem_backbone", "memory_encoder"),
    ("mlp.lin1", "mlp.layers.0"),
    ("mlp.lin2", "mlp.layers.1"),
)


def _converted_path(src):
    """Return the output path for *src*: same directory, file name tagged ``_converted``."""
    import os  # local import so this block does not depend on the file's import header

    directory, filename = os.path.split(src)
    # Only the file name is rewritten; a ".pt" inside a directory name must not
    # be touched, otherwise the save would target a directory that does not exist.
    head, sep, tail = filename.rpartition(".pt")
    if sep:
        # "model.pt" -> "model_converted.pt", "model.pth" -> "model_converted.pth"
        tagged = f"{head}_converted.pt{tail}"
    else:
        # No ".pt" in the name: tag before the extension instead of returning the
        # source path unchanged (which would overwrite the input checkpoint).
        root, ext = os.path.splitext(filename)
        tagged = f"{root}_converted{ext}"
    return os.path.join(directory, tagged)


def main(args):
    """Rename the keys of a checkpoint's ``"model"`` dict and save the result.

    Loads ``args.src`` onto the CPU, takes its ``"model"`` entry, drops every
    key containing ``"teacher"``, applies the substring renames in
    ``_KEY_RENAMES`` in order, and writes ``{"model": renamed}`` to a sibling
    file whose name carries a ``_converted`` tag.

    Args:
        args: Object with a ``src`` attribute holding the checkpoint path
            (the ``argparse`` namespace when run as a script).

    Raises:
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    sd = torch.load(args.src, map_location="cpu")["model"]
    sd = {k: v for k, v in sd.items() if "teacher" not in k}
    for old, new in _KEY_RENAMES:
        sd = {k.replace(old, new): v for k, v in sd.items()}
    torch.save({"model": sd}, _converted_path(args.src))
if __name__ == "__main__":
    # Script entry point: --src is the path of the checkpoint to convert.
    cli = argparse.ArgumentParser()
    cli.add_argument("--src", type=str, required=True)
    main(cli.parse_args())
|