#!/usr/bin/env python3

"""Extract the model backbone from the checkpoint."""

import torch

from torchgeo.models import dofa_base_patch16_224

# Load the checkpoint
in_filename = "ofa_base_checkpoint_e99.pth"
checkpoint = torch.load(in_filename, map_location=torch.device("cpu"))

# Remove extra keys
weights = checkpoint["model"]
del weights["mask_token"]
del weights["norm.weight"], weights["norm.bias"]
del weights["projector.weight"], weights["projector.bias"]

# Load the weights to ensure they are valid
# fc_norm and head are generated dynamically
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
model = dofa_base_patch16_224()
missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
assert set(missing_keys) <= allowed_missing_keys
assert not unexpected_keys

# Save the cleaned checkpoint
# Should be manually renamed later, add first 8 digits of sha256 to suffix
out_filename = "dofa_base_patch16_224.pth"
torch.save(weights, out_filename)
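
# A minimal sketch of the manual renaming step mentioned above: compute the
# first 8 hex digits of the saved file's sha256 and print a suggested name.
# The "<name>-<sha256 prefix>.pth" convention shown here is an assumption,
# not something this script enforces.
import hashlib

with open(out_filename, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print(f"suggested filename: dofa_base_patch16_224-{digest[:8]}.pth")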