"""Factories for SAM-family predictors and a PerSAM model fitted on one example."""
import time

import numpy.typing as npt

# Package-relative forms of the absolute imports used below, left commented
# out exactly as they were:
# from .sam import build_sam, SamPredictor
# from .sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor
# from .mobile_sam import (
#     build_sam_vit_t as build_mobile_sam,
#     SamPredictor as MobileSamPredictor,
# )
# from .per_sam import train, PerSAM
# from .configs import DEVICE
from app.sam import build_sam, SamPredictor
from app.sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor
from app.mobile_sam import (
    build_sam_vit_t as build_mobile_sam,
    SamPredictor as MobileSamPredictor,
)
from app.per_sam import train, PerSAM
from app.configs import DEVICE


def build_sam_predictor(checkpoint: str | None = None):
    """Return a ``SamPredictor`` around a SAM model moved to ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_sam``; how ``None`` is
            handled is decided there.
    """
    return SamPredictor(build_sam(checkpoint).to(DEVICE))


def build_sam_hq_predictor(checkpoint: str | None = None):
    """Return a ``SamHqPredictor`` around a SAM-HQ model moved to ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_sam_hq``; how ``None`` is
            handled is decided there.
    """
    return SamHqPredictor(build_sam_hq(checkpoint).to(DEVICE))


def build_mobile_sam_predictor(checkpoint: str | None = None):
    """Return a ``MobileSamPredictor`` around a MobileSAM model on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_mobile_sam``; how ``None``
            is handled is decided there.
    """
    return MobileSamPredictor(build_mobile_sam(checkpoint).to(DEVICE))


def get_multi_label_predictor(
    sam: MobileSamPredictor,
    image: npt.NDArray,
    mask: npt.NDArray,
) -> PerSAM:
    """Fit PerSAM on one image/mask pair and return the resulting model.

    The pair is passed to ``train`` as two single-element lists, and the
    elapsed wall-clock time of that call is printed to stdout.

    Args:
        sam: Predictor handed to both ``train`` and ``PerSAM``.
        image: Example image array.
        mask: Mask array paired with ``image``.

    Returns:
        A ``PerSAM`` built from ``sam``, the target feature and the weights
        returned by ``train``.
    """
    example_images = [image]
    example_masks = [mask]

    start = time.perf_counter()
    learned_weights, target_features = train(sam, example_images, example_masks)
    print(f"training time {time.perf_counter() - start}")

    # NOTE(review): the meaning of 10, 0.4 and 0.2 is defined by PerSAM's
    # signature, which is not visible in this file - confirm before changing.
    return PerSAM(sam, target_features, 10, 0.4, 0.2, learned_weights)


if __name__ == "__main__":
    import numpy as np
    from PIL import Image
    from torchvision.transforms.functional import resize

    from app.transforms import ResizeLongestSide

    longest_side = ResizeLongestSide(1024)

    # Example image: convert to RGB first, then resize.
    rgb_image = Image.open(
        "/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.png"
    ).convert("RGB")
    # PIL reports size as (width, height); the transform is given height first.
    image_width, image_height = rgb_image.size
    image_shape = longest_side.get_preprocess_shape(
        image_height, image_width, longest_side.target_length
    )
    image_array = np.array(resize(rgb_image, image_shape))

    # Example mask: resize first, then convert to single-channel "L".
    raw_mask = Image.open(
        "/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.seal.10.png"
    )
    mask_width, mask_height = raw_mask.size
    mask_shape = longest_side.get_preprocess_shape(
        mask_height, mask_width, longest_side.target_length
    )
    mask_array = np.array(resize(raw_mask, mask_shape).convert("L"))

    predictor = build_mobile_sam_predictor(
        "/Users/dillonlaird/code/instance_labeler/mobile_sam.pth"
    )

    # Time fitting (the helper also prints its own inner training time).
    start = time.perf_counter()
    per_sam_model = get_multi_label_predictor(predictor, image_array, mask_array)
    print(f"training time {time.perf_counter() - start}")

    # Time one prediction on the same image used for fitting.
    start = time.perf_counter()
    masks, bboxes, _ = per_sam_model(image_array)
    print(f"prediction time {time.perf_counter() - start}")