Spaces:
Runtime error
Runtime error
File size: 1,958 Bytes
ad5354d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from typing import Optional

from src.efficientvit.models.efficientvit import (
    EfficientViTSam,
    efficientvit_sam_l0,
    efficientvit_sam_l1,
    efficientvit_sam_l2,
    efficientvit_sam_xl0,
    efficientvit_sam_xl1,
)
from src.efficientvit.models.nn.norm import set_norm_eps
from src.efficientvit.models.utils import load_state_dict_from_file
# Public API of this module: only the factory function is exported by
# `from <module> import *`.
__all__ = ["create_sam_model"]
# Default checkpoint path for each EfficientViT-SAM variant, keyed by model
# name. create_sam_model() looks up its full `name` argument in this table
# when pretrained weights are requested and no explicit weight_url is given.
REGISTERED_SAM_MODEL: dict[str, str] = {
    variant: f"assets/checkpoints/sam/{variant}.pt"
    for variant in ("l0", "l1", "l2", "xl0", "xl1")
}
def create_sam_model(
    name: str,
    pretrained: bool = True,
    weight_url: Optional[str] = None,
    **kwargs,
) -> EfficientViTSam:
    """Build an EfficientViT-SAM model by name, optionally loading weights.

    Args:
        name: Model name. The part before the first ``"-"`` selects the
            architecture (``"l0"``, ``"l1"``, ``"l2"``, ``"xl0"`` or
            ``"xl1"``). The *full* name is the key used to look up default
            weights in ``REGISTERED_SAM_MODEL``. A suffixed name such as
            ``"l0-foo"`` therefore builds the ``l0`` architecture but needs
            its own registry entry or an explicit ``weight_url`` when
            ``pretrained`` is true.
        pretrained: If true, load a state dict into the model before
            returning it.
        weight_url: Explicit weight location passed to
            ``load_state_dict_from_file``. If falsy (``None`` or ``""``),
            the registry entry for ``name`` is used instead.
        **kwargs: Forwarded unchanged to the selected architecture
            constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: If the architecture prefix of ``name`` is unknown, or if
            ``pretrained`` is true and no weight location can be resolved.
    """
    model_dict = {
        "l0": efficientvit_sam_l0,
        "l1": efficientvit_sam_l1,
        "l2": efficientvit_sam_l2,
        "xl0": efficientvit_sam_xl0,
        "xl1": efficientvit_sam_xl1,
    }

    # Only the prefix picks the architecture; suffixes are ignored here.
    model_id = name.split("-")[0]
    if model_id not in model_dict:
        raise ValueError(
            f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
        )
    model = model_dict[model_id](**kwargs)

    # Applied to every built model, whether or not weights are loaded below.
    set_norm_eps(model, 1e-6)

    if pretrained:
        # `or` (rather than an `is None` check) means an empty string also
        # falls back to the registry. Lookup uses the full name, not model_id.
        weight_url = weight_url or REGISTERED_SAM_MODEL.get(name)
        if weight_url is None:
            raise ValueError(f"Do not find the pretrained weight of {name}.")
        weight = load_state_dict_from_file(weight_url)
        model.load_state_dict(weight)
    return model
|