File size: 2,698 Bytes
11d7b39
 
e4defb0
11d7b39
 
 
 
 
 
 
 
 
e4defb0
11d7b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a34578b
 
 
 
 
 
 
607e627
a34578b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import os
import torch

from modules.mask_utils import *
from modules.model_downloader import *


class SamInference:
    """Wrapper around a Segment Anything (SAM) automatic mask generator.

    The SAM checkpoint is loaded lazily the first time a mask generator is
    needed. If the checkpoint file is missing, it is downloaded first. The
    loaded model is then kept and reused. When the tunable parameters change,
    only the ``SamAutomaticMaskGenerator`` is rebuilt; the model is not
    reloaded. To force a reload (e.g. after changing ``model_path``), set
    ``self.model`` back to ``None``.
    """

    def __init__(self):
        self.model = None
        self.model_path = "models/sam_vit_h_4b8939.pth"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None

        # Tunable parameters forwarded to SamAutomaticMaskGenerator; these are
        # the initial (default) values used until the caller supplies others.
        self.tunable_params = {
            'points_per_side': 32,
            'pred_iou_thresh': 0.88,
            'stability_score_thresh': 0.95,
            'crop_n_layers': 0,
            'crop_n_points_downscale_factor': 1,
            'min_mask_region_area': 0
        }

    def _load_model(self):
        """Load the SAM checkpoint onto ``self.device`` unless already loaded."""
        if self.model is not None:
            return
        if not os.path.exists(self.model_path):
            print("No needed SAM model detected. downloading VIT H SAM model....")
            download_sam_model_url()

        model = sam_model_registry["default"](checkpoint=self.model_path)
        model.to(device=self.device)
        # Publish the model only once it is fully built and on the target
        # device, so a failure above leaves ``self.model`` as None.
        self.model = model

    def set_mask_generator(self):
        """(Re)build ``self.mask_generator`` from ``self.tunable_params``.

        The SAM model is loaded (and downloaded if needed) on first use and
        reused afterwards. The generator is configured to emit COCO RLE masks.
        """
        print("applying configs to model..")
        # Invalidate first: if anything below raises, callers must not keep
        # using a generator that was built with outdated parameters.
        self.mask_generator = None
        self._load_model()

        params = self.tunable_params
        self.mask_generator = SamAutomaticMaskGenerator(
            self.model,
            points_per_side=params['points_per_side'],
            pred_iou_thresh=params['pred_iou_thresh'],
            stability_score_thresh=params['stability_score_thresh'],
            crop_n_layers=params['crop_n_layers'],
            crop_n_points_downscale_factor=params['crop_n_points_downscale_factor'],
            min_mask_region_area=params['min_mask_region_area'],
            output_mode="coco_rle",
        )

    def generate_mask(self, image):
        """Generate masks for *image* with the current tunable parameters.

        Builds the mask generator on first use.

        Args:
            image: Passed straight to ``SamAutomaticMaskGenerator.generate``
                (presumably an HWC uint8 numpy array — confirm with callers).

        Returns:
            A one-element list wrapping the generator's list of mask records.
        """
        if self.mask_generator is None:
            self.set_mask_generator()
        return [self.mask_generator.generate(image)]

    def generate_mask_app(self, image, *params):
        """UI entry point: generate masks and build preview images.

        Args:
            image: Image passed to the mask generator and the mask helpers.
            *params: Six values, in order: points_per_side (int),
                pred_iou_thresh (float), stability_score_thresh (float),
                crop_n_layers (int), crop_n_points_downscale_factor (int),
                min_mask_region_area (int). They are coerced to those types.

        Returns:
            ``[combined_image] + gallery`` on success. If anything fails during
            generation, the error is printed and ``None`` is returned, so the
            UI callback does not raise.
        """
        tunable_params = {
            'points_per_side': int(params[0]),
            'pred_iou_thresh': float(params[1]),
            'stability_score_thresh': float(params[2]),
            'crop_n_layers': int(params[3]),
            'crop_n_points_downscale_factor': int(params[4]),
            'min_mask_region_area': int(params[5]),
        }

        try:
            # Rebuild only when necessary; the model itself is reused.
            if self.model is None or self.mask_generator is None or self.tunable_params != tunable_params:
                self.tunable_params = tunable_params
                self.set_mask_generator()
            masks = self.mask_generator.generate(image)
            combined_image = create_mask_combined_images(image, masks)
            gallery = create_mask_gallery(image, masks)
            return [combined_image] + gallery
        except Exception as e:
            # Deliberate best-effort behaviour for the UI: report and return None.
            print(e)
            return None