File size: 2,573 Bytes
6723494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80c2e04
6723494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy.typing as npt
import time

# from .sam import build_sam, SamPredictor
# from .sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor
# from .mobile_sam import (
#     build_sam_vit_t as build_mobile_sam,
#     SamPredictor as MobileSamPredictor,
# )
# from .per_sam import train, PerSAM
# from .configs import DEVICE

from app.sam import build_sam, SamPredictor
from app.sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor
from app.mobile_sam import (
    build_sam_vit_t as build_mobile_sam,
    SamPredictor as MobileSamPredictor,
)
from app.per_sam import train, PerSAM
from app.configs import DEVICE


def build_sam_predictor(checkpoint: str | None = None):
    """Create a ``SamPredictor`` backed by a SAM model placed on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_sam``. How ``None`` is
            interpreted is decided by ``build_sam``, not by this function.

    Returns:
        A ``SamPredictor`` wrapping the model returned by ``.to(DEVICE)``.
    """
    # Build, move to the configured device, and wrap in one expression.
    return SamPredictor(build_sam(checkpoint).to(DEVICE))


def build_sam_hq_predictor(checkpoint: str | None = None):
    """Create a ``SamHqPredictor`` backed by a SAM-HQ model placed on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_sam_hq``. How ``None`` is
            interpreted is decided by ``build_sam_hq``, not by this function.

    Returns:
        A ``SamHqPredictor`` wrapping the model returned by ``.to(DEVICE)``.
    """
    # Build, move to the configured device, and wrap in one expression.
    return SamHqPredictor(build_sam_hq(checkpoint).to(DEVICE))


def build_mobile_sam_predictor(checkpoint: str | None = None):
    """Create a ``MobileSamPredictor`` backed by a MobileSAM model on ``DEVICE``.

    Args:
        checkpoint: Forwarded unchanged to ``build_mobile_sam``. How ``None``
            is interpreted is decided by ``build_mobile_sam``, not by this
            function.

    Returns:
        A ``MobileSamPredictor`` wrapping the model returned by ``.to(DEVICE)``.
    """
    # Build, move to the configured device, and wrap in one expression.
    return MobileSamPredictor(build_mobile_sam(checkpoint).to(DEVICE))


def get_multi_label_predictor(
    sam: MobileSamPredictor, image: npt.NDArray, mask: npt.NDArray,
) -> PerSAM:
    """Train on one image/mask pair and return the resulting ``PerSAM`` model.

    ``train`` is called with ``image`` and ``mask`` each wrapped in a
    single-element list. The wall-clock duration of that call is printed to
    stdout.

    Args:
        sam: Predictor handed to both ``train`` and ``PerSAM``.
        image: Reference image array.
        mask: Reference mask array paired with ``image``.

    Returns:
        A ``PerSAM`` instance built from ``sam`` and the two values returned
        by ``train``.
    """
    began = time.perf_counter()
    learned_weights, reference_feat = train(sam, [image], [mask])
    elapsed = time.perf_counter() - began
    print(f"training time {elapsed}")
    # NOTE(review): 10, 0.4 and 0.2 are passed positionally; their meaning is
    # defined by PerSAM's constructor, which is not visible here.
    return PerSAM(sam, reference_feat, 10, 0.4, 0.2, learned_weights)


if __name__ == "__main__":
    # Manual smoke test: fit a PerSAM model on one labelled image and time a
    # prediction on that same image. The input paths used to be hard-coded;
    # they are now command-line options whose defaults are the original
    # values, so running without arguments behaves as before.
    import argparse

    import numpy as np
    from PIL import Image
    from torchvision.transforms.functional import resize
    from app.transforms import ResizeLongestSide

    parser = argparse.ArgumentParser(
        description="Train PerSAM on one image/mask pair and time a prediction."
    )
    parser.add_argument(
        "--image",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.png",
        help="path to the reference image",
    )
    parser.add_argument(
        "--mask",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.seal.10.png",
        help="path to the mask image paired with --image",
    )
    parser.add_argument(
        "--checkpoint",
        default="/Users/dillonlaird/code/instance_labeler/mobile_sam.pth",
        help="path to the MobileSAM checkpoint",
    )
    args = parser.parse_args()

    T = ResizeLongestSide(1024)

    # PIL's ``size`` is (width, height); get_preprocess_shape is called with
    # index 1 first and index 0 second, i.e. (height, width).
    image = Image.open(args.image).convert("RGB")
    target_size = T.get_preprocess_shape(image.size[1], image.size[0], T.target_length)
    image_np = np.array(resize(image, target_size))

    # The mask is resized the same way, then converted to single-channel "L".
    # NOTE(review): resize() runs with its default interpolation here, which
    # may blend mask edges - confirm that is what train() expects.
    mask = Image.open(args.mask)
    target_size = T.get_preprocess_shape(mask.size[1], mask.size[0], T.target_length)
    mask_np = np.array(resize(mask, target_size).convert("L"))

    model = build_mobile_sam_predictor(args.checkpoint)

    # get_multi_label_predictor prints its own "training time" line; this
    # outer measurement covers the whole call.
    start = time.perf_counter()
    per_sam_model = get_multi_label_predictor(model, image_np, mask_np)
    print(f"training time {time.perf_counter() - start}")

    start = time.perf_counter()
    masks, bboxes, _ = per_sam_model(image_np)
    print(f"prediction time {time.perf_counter() - start}")