| |
| """Convert HF release safetensors checkpoints into official repo checkpoint layouts.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from argparse import Namespace |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser( |
| description=( |
| "Convert a safetensors release checkpoint into the PyTorch checkpoint " |
| "layout expected by the official RT-DETRv4 or RF-DETR repositories." |
| ) |
| ) |
| parser.add_argument( |
| "--framework", |
| choices=("rtdetrv4", "rfdetr"), |
| required=True, |
| help="Target official repository format.", |
| ) |
| parser.add_argument( |
| "--input", |
| type=Path, |
| required=True, |
| help="Input .safetensors checkpoint path.", |
| ) |
| parser.add_argument( |
| "--output", |
| type=Path, |
| required=True, |
| help="Output .pth checkpoint path.", |
| ) |
| parser.add_argument( |
| "--class-names", |
| nargs="+", |
| default=["person", "head"], |
| help="Class names to store in RF-DETR checkpoint metadata.", |
| ) |
| return parser.parse_args() |
|
|
|
|
def main() -> None:
    """Re-save a safetensors checkpoint in an official repository's layout.

    Reads the tensors from ``--input`` with ``safetensors.torch.load_file``,
    creates the parent directory of ``--output`` if needed, and writes a
    ``torch.save`` payload whose ``"model"`` entry holds those tensors. For any
    framework other than ``rtdetrv4`` (i.e. ``rfdetr``), an ``"args"`` entry
    holding an ``argparse.Namespace`` with ``class_names`` is added too.
    Finally prints a short conversion summary to stdout.
    """
    cli = parse_args()
    tensors = load_file(str(cli.input))
    cli.output.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {"model": tensors}
    if cli.framework != "rtdetrv4":
        # RF-DETR layout: class names travel with the weights as checkpoint
        # metadata under "args".
        checkpoint["args"] = Namespace(class_names=cli.class_names)

    torch.save(checkpoint, cli.output)

    report = [
        f"Converted {cli.input} -> {cli.output}",
        f"Framework: {cli.framework}",
        f"Tensors: {len(tensors)}",
    ]
    if cli.framework == "rfdetr":
        report.append(f"Class names: {cli.class_names}")
    print("\n".join(report))
|
|
|
|
if __name__ == "__main__":
    # Script entry point only; importing the module performs no conversion.
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())
|
|