# auto-labeler / app/model.py
# Author: dillonlaird
# Commit 80c2e04: changed hyperparameters for mobile sam
import numpy.typing as npt
import time
# from .sam import build_sam, SamPredictor
# from .sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor
# from .mobile_sam import (
# build_sam_vit_t as build_mobile_sam,
# SamPredictor as MobileSamPredictor,
# )
# from .per_sam import train, PerSAM
# from .configs import DEVICE
from app.sam import build_sam, SamPredictor
from app.sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor
from app.mobile_sam import (
build_sam_vit_t as build_mobile_sam,
SamPredictor as MobileSamPredictor,
)
from app.per_sam import train, PerSAM
from app.configs import DEVICE
def build_sam_predictor(checkpoint: str | None = None):
    """Return a ``SamPredictor`` wrapping a SAM model loaded from *checkpoint*.

    The model is created with ``build_sam`` and moved to ``DEVICE`` before it
    is wrapped.

    Args:
        checkpoint: forwarded unchanged to ``build_sam``; how ``None`` is
            handled is up to ``app.sam.build_sam`` (not visible here).
    """
    model = build_sam(checkpoint).to(DEVICE)
    return SamPredictor(model)
def build_sam_hq_predictor(checkpoint: str | None = None):
    """Return a ``SamHqPredictor`` wrapping a SAM-HQ model loaded from *checkpoint*.

    The model is created with ``build_sam_hq`` and moved to ``DEVICE`` before
    it is wrapped.

    Args:
        checkpoint: forwarded unchanged to ``build_sam_hq``; how ``None`` is
            handled is up to ``app.sam_hq.build_sam`` (not visible here).
    """
    model = build_sam_hq(checkpoint).to(DEVICE)
    return SamHqPredictor(model)
def build_mobile_sam_predictor(checkpoint: str | None = None):
    """Return a ``MobileSamPredictor`` wrapping a MobileSAM model loaded from *checkpoint*.

    The model is created with ``build_mobile_sam`` (``build_sam_vit_t``) and
    moved to ``DEVICE`` before it is wrapped.

    Args:
        checkpoint: forwarded unchanged to ``build_mobile_sam``; how ``None``
            is handled is up to ``app.mobile_sam`` (not visible here).
    """
    model = build_mobile_sam(checkpoint).to(DEVICE)
    return MobileSamPredictor(model)
def get_multi_label_predictor(
    sam: MobileSamPredictor, image: npt.NDArray, mask: npt.NDArray,
) -> PerSAM:
    """Fit a PerSAM model to a single reference image/mask pair.

    Calls ``train`` with one-element lists holding *image* and *mask*, prints
    the wall-clock training time, and wraps the returned ``target_feat`` and
    ``weights`` in a ``PerSAM`` built on *sam*.

    Args:
        sam: predictor passed to both ``train`` and ``PerSAM``.
        image: reference image array.
        mask: reference mask array for *image*.

    Returns:
        The constructed ``PerSAM`` model.
    """
    images, masks = [image], [mask]

    t0 = time.perf_counter()
    weights, target_feat = train(sam, images, masks)
    elapsed = time.perf_counter() - t0
    print(f"training time {elapsed}")

    # NOTE(review): the meaning of the fixed arguments 10 / 0.4 / 0.2 is defined
    # by PerSAM in app.per_sam (not visible here) — confirm before tuning.
    return PerSAM(sam, target_feat, 10, 0.4, 0.2, weights)
def _main() -> None:
    """Smoke-test PerSAM on one image/mask pair and report training and prediction times.

    Input paths default to the original author's local files and can be
    overridden with ``--image``, ``--mask`` and ``--checkpoint``.
    """
    import argparse

    import numpy as np
    from PIL import Image
    from torchvision.transforms.functional import InterpolationMode, resize

    from app.transforms import ResizeLongestSide

    parser = argparse.ArgumentParser(description="Train and run PerSAM on one image/mask pair.")
    parser.add_argument(
        "--image",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.png",
    )
    parser.add_argument(
        "--mask",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.seal.10.png",
    )
    parser.add_argument(
        "--checkpoint",
        default="/Users/dillonlaird/code/instance_labeler/mobile_sam.pth",
    )
    args = parser.parse_args()

    transform = ResizeLongestSide(1024)

    image = Image.open(args.image).convert("RGB")
    # PIL's Image.size is (width, height), so height is passed first here.
    target_size = transform.get_preprocess_shape(
        image.size[1], image.size[0], transform.target_length
    )
    image_np = np.array(resize(image, target_size))

    # Convert to grayscale before resizing and use nearest-neighbour sampling so
    # the mask keeps its original label values instead of gaining blended edge
    # pixels. The mask is resized to the image's target size so the arrays align.
    mask = Image.open(args.mask).convert("L")
    mask_np = np.array(
        resize(mask, target_size, interpolation=InterpolationMode.NEAREST)
    )

    model = build_mobile_sam_predictor(args.checkpoint)

    # get_multi_label_predictor already prints its own training time; timing it
    # again here would emit a duplicate "training time" line.
    per_sam_model = get_multi_label_predictor(model, image_np, mask_np)

    start = time.perf_counter()
    _masks, _bboxes, _ = per_sam_model(image_np)
    print(f"prediction time {time.perf_counter() - start}")


if __name__ == "__main__":
    _main()