Spaces:
Runtime error
Runtime error
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator | |
from sam2.build_sam import build_sam2 | |
from sam2.sam2_image_predictor import SAM2ImagePredictor | |
import torch | |
import os | |
from datetime import datetime | |
import numpy as np | |
from modules.model_downloader import ( | |
AVAILABLE_MODELS, DEFAULT_MODEL_TYPE, OUTPUT_DIR, | |
is_sam_exist, | |
download_sam_model_url | |
) | |
from modules.paths import SAM2_CONFIGS_DIR, MODELS_DIR | |
from modules.mask_utils import ( | |
save_psd_with_masks, | |
create_mask_combined_images, | |
create_mask_gallery | |
) | |
# Maps each supported SAM2 model type to the path of its YAML config file
# inside SAM2_CONFIGS_DIR.
CONFIGS = {
    model_name: os.path.join(SAM2_CONFIGS_DIR, yaml_name)
    for model_name, yaml_name in (
        ("sam2_hiera_tiny", "sam2_hiera_t.yaml"),
        ("sam2_hiera_small", "sam2_hiera_s.yaml"),
        ("sam2_hiera_base_plus", "sam2_hiera_b+.yaml"),
        ("sam2_hiera_large", "sam2_hiera_l.yaml"),
    )
}
class SamInference:
    """Load a SAM2 model and run automatic mask generation with it.

    The instance keeps the currently loaded model, the predictor objects
    built on top of it, and the mask-generator hyper-parameters they were
    built with, so that repeated calls only rebuild what actually changed.
    """

    def __init__(self,
                 model_dir: str = MODELS_DIR,
                 output_dir: str = OUTPUT_DIR
                 ):
        """Set up paths, device and default hyper-parameters; no model is loaded yet.

        Args:
            model_dir: Directory that holds (or will receive) the checkpoints.
            output_dir: Base directory under which results are written.
        """
        self.model = None
        self.available_models = list(AVAILABLE_MODELS.keys())
        self.model_type = DEFAULT_MODEL_TYPE
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.model_path = os.path.join(self.model_dir, AVAILABLE_MODELS[DEFAULT_MODEL_TYPE][0])
        # Use CUDA when torch reports it as available, otherwise the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        self.image_predictor = None

        # Tunable parameters. All default values are taken from
        # https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/automatic_mask_generator_example.ipynb
        self.maskgen_hparams = {
            "points_per_side": 64,
            "points_per_batch": 128,
            "pred_iou_thresh": 0.7,
            "stability_score_thresh": 0.92,
            "stability_score_offset": 0.7,
            "crop_n_layers": 1,
            "box_nms_thresh": 0.7,
            "crop_n_points_downscale_factor": 2,
            "min_mask_region_area": 25.0,
            "use_m2m": True,
        }

    def load_model(self):
        """Build the SAM2 model for ``self.model_type`` and store it in ``self.model``.

        The checkpoint is downloaded first when ``is_sam_exist`` reports that
        it is missing.

        Raises:
            KeyError: If ``self.model_type`` is not a key of ``CONFIGS`` or
                ``AVAILABLE_MODELS``.
            RuntimeError: If ``build_sam2`` fails. ``self.model`` is reset to
                ``None`` in that case so that a later call retries the load
                instead of silently reusing a previously loaded model.
        """
        config = CONFIGS[self.model_type]
        filename, _url = AVAILABLE_MODELS[self.model_type]
        model_path = os.path.join(self.model_dir, filename)
        # Keep the public attribute in sync with the checkpoint actually used.
        self.model_path = model_path

        if not is_sam_exist(self.model_type):
            print(f"\nLayer Divider Extension : No SAM2 model found, downloading {self.model_type} model...")
            download_sam_model_url(self.model_type)
        print("\nLayer Divider Extension : applying configs to model..")

        try:
            self.model = build_sam2(
                config_file=config,
                ckpt_path=model_path,
                device=self.device
            )
        except Exception as e:
            # Drop any stale model: after a failed switch, self.model_type no
            # longer matches whatever was loaded before.
            self.model = None
            print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
            raise RuntimeError(
                f"Failed to load SAM2 model '{self.model_type}' from '{model_path}'"
            ) from e

    def set_predictors(self):
        """(Re)build the image predictor and mask generator from ``self.model``.

        Loads the model first if none is loaded. The mask generator is
        created with the current ``self.maskgen_hparams``.
        """
        if self.model is None:
            self.load_model()

        self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
        self.mask_generator = SAM2AutomaticMaskGenerator(
            model=self.model,
            **self.maskgen_hparams
        )

    def generate_mask(self,
                      image: np.ndarray):
        """Run the automatic mask generator on ``image`` and return its result.

        The predictors are built on first use, so this can be called on a
        fresh instance without calling ``set_predictors`` beforehand.
        """
        if self.mask_generator is None:
            self.set_predictors()
        return self.mask_generator.generate(image)

    def generate_mask_app(self,
                          image: np.ndarray,
                          model_type: str,
                          *params
                          ):
        """Generate masks for ``image`` and save them as a PSD file.

        Args:
            image: Input image passed to the mask generator and to the
                ``mask_utils`` helpers.
            model_type: Model to use; must be a key of ``AVAILABLE_MODELS``.
            *params: Ten positional hyper-parameters, in this order:
                points_per_side, points_per_batch, pred_iou_thresh,
                stability_score_thresh, stability_score_offset,
                crop_n_layers, box_nms_thresh,
                crop_n_points_downscale_factor, min_mask_region_area,
                use_m2m.

        Returns:
            A tuple ``(images, output_path)`` where ``images`` is the combined
            mask image followed by the gallery entries, and ``output_path``
            is the path of the written PSD file.

        Raises:
            IndexError: If fewer than ten values are passed in ``params``.
            RuntimeError: If the requested model cannot be loaded.
        """
        maskgen_hparams = {
            'points_per_side': int(params[0]),
            'points_per_batch': int(params[1]),
            'pred_iou_thresh': float(params[2]),
            'stability_score_thresh': float(params[3]),
            'stability_score_offset': float(params[4]),
            'crop_n_layers': int(params[5]),
            'box_nms_thresh': float(params[6]),
            'crop_n_points_downscale_factor': int(params[7]),
            'min_mask_region_area': int(params[8]),
            'use_m2m': bool(params[9])
        }
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"result-{timestamp}.psd"
        output_path = os.path.join(self.output_dir, "psd", output_file_name)
        # Make sure the destination folder exists before the PSD is written.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        model_reloaded = self.model is None or self.model_type != model_type
        if model_reloaded:
            self.model_type = model_type
            self.load_model()

        # The predictors hold a reference to the model they were built with,
        # so they must be rebuilt after a reload even when the
        # hyper-parameters are unchanged.
        if (model_reloaded
                or self.mask_generator is None
                or self.maskgen_hparams != maskgen_hparams):
            self.maskgen_hparams = maskgen_hparams
            self.set_predictors()

        masks = self.mask_generator.generate(image)

        save_psd_with_masks(image, masks, output_path)
        combined_image = create_mask_combined_images(image, masks)
        gallery = create_mask_gallery(image, masks)
        return [combined_image] + gallery, output_path