File size: 774 Bytes
010b878
cf5fa6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import importlib, torch, cv2, numpy as np
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry, SamPredictor
from groundingdino.util.inference import Model as GDModel

# Target device for both models: a CUDA device when PyTorch reports one, else CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

@torch.inference_mode()
def load_sam():
    ckpt = hf_hub_download("ybelkada/segment-anything", "sam_vit_b.pth")
    sam = sam_model_registry["vit_b"](checkpoint=ckpt)
    return SamPredictor(sam.to(DEVICE))

@torch.inference_mode()
def load_groundingdino():
    ckpt = hf_hub_download(
        "GroundingDINO/groundingdino-swint-ogc",
        "groundingdino_swint_ogc.pth"
    )
    return GDModel(model_config_path=None, model_checkpoint_path=ckpt, device=DEVICE)

# Importing this module downloads (if not cached) and loads both models.
# They are bound one after another, so SAM is available before GD is loaded.
SAM = load_sam()  # segment-anything predictor
GD = load_groundingdino()  # GroundingDINO inference model