import argparse
from typing import Any, List, Optional, Tuple

import torch
import torch.backends.cudnn as cudnn

from dinov2.models import build_model_from_cfg
from dinov2.utils.config import setup
import dinov2.utils.utils as dinov2_utils


def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the common command-line parser shared by the evaluation scripts."""
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents or [],
        add_help=add_help,
    )
    parser.add_argument(
        "--config-file",
        type=str,
        help="Model configuration file",
    )
    parser.add_argument(
        "--pretrained-weights",
        type=str,
        help="Pretrained model weights",
    )
    parser.add_argument(
        "--output-dir",
        default="",
        type=str,
        help="Output directory to write results and logs",
    )
    parser.add_argument(
        "--opts",
        help="Extra configuration options",
        default=[],
        nargs="+",
    )
    return parser
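

# Usage sketch (illustrative, not part of this module): a downstream evaluation
# script could reuse this parser as a parent and add its own options. The
# "--batch-size" flag below is an assumed example, not an option defined here.
#
#     eval_parser = argparse.ArgumentParser(
#         description="Example evaluation script",
#         parents=[get_args_parser(add_help=False)],
#     )
#     eval_parser.add_argument("--batch-size", type=int, default=128)
#     args = eval_parser.parse_args()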


def get_autocast_dtype(config):
    """Translate the teacher backbone's mixed-precision param_dtype into a torch dtype."""
    teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype
    if teacher_dtype_str == "fp16":
        return torch.half
    elif teacher_dtype_str == "bf16":
        return torch.bfloat16
    else:
        return torch.float
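# Usage sketch: the returned dtype is meant to drive an autocast context around
# the forward pass; `model` and `images` below are assumed to exist.
#
#     with torch.inference_mode(), torch.autocast("cuda", dtype=get_autocast_dtype(config)):
#         features = model(images)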


def build_model_for_eval(config, pretrained_weights):
    """Build the teacher backbone, load its pretrained weights, and move it to GPU in eval mode."""
    model, _ = build_model_from_cfg(config, only_teacher=True)
    # Weights are read from the "teacher" key of the checkpoint.
    dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher")
    model.eval()
    model.cuda()
    return model


def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
    """Set up the config from parsed CLI args and return the eval model with its autocast dtype."""
    cudnn.benchmark = True
    config = setup(args)
    model = build_model_for_eval(config, args.pretrained_weights)
    autocast_dtype = get_autocast_dtype(config)
    return model, autocast_dtype
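

# Minimal end-to-end sketch, assuming a CUDA device and valid --config-file,
# --pretrained-weights, and --output-dir arguments; the (1, 3, 224, 224) dummy
# input is illustrative and assumes the backbone maps an image batch to a
# feature tensor. This guard is an example, not part of the original module API.
if __name__ == "__main__":
    args = get_args_parser(description="DINOv2 evaluation setup").parse_args()
    model, autocast_dtype = setup_and_build_model(args)
    dummy = torch.randn(1, 3, 224, 224, device="cuda")
    with torch.inference_mode(), torch.autocast("cuda", dtype=autocast_dtype):
        features = model(dummy)
    print(f"feature shape: {tuple(features.shape)}, autocast dtype: {autocast_dtype}")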