Spaces:
Runtime error
Runtime error
jhj0517
commited on
Commit
•
8d52a7d
1
Parent(s):
f822d17
Divide model / predictors
Browse files- modules/sam_inference.py +15 -7
modules/sam_inference.py
CHANGED
@@ -71,14 +71,19 @@ class SamInference:
|
|
71 |
ckpt_path=model_path,
|
72 |
device=self.device
|
73 |
)
|
74 |
-
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
75 |
-
self.mask_generator = SAM2AutomaticMaskGenerator(
|
76 |
-
model=self.model,
|
77 |
-
**self.maskgen_hparams
|
78 |
-
)
|
79 |
except Exception as e:
|
80 |
print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
def generate_mask(self,
|
83 |
image: np.ndarray):
|
84 |
return self.mask_generator.generate(image)
|
@@ -104,11 +109,14 @@ class SamInference:
|
|
104 |
output_file_name = f"result-{timestamp}.psd"
|
105 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
106 |
|
107 |
-
if self.model is None or self.
|
108 |
self.model_type = model_type
|
109 |
-
self.maskgen_hparams = maskgen_hparams
|
110 |
self.load_model()
|
111 |
|
|
|
|
|
|
|
|
|
112 |
masks = self.mask_generator.generate(image)
|
113 |
|
114 |
save_psd_with_masks(image, masks, output_path)
|
|
|
71 |
ckpt_path=model_path,
|
72 |
device=self.device
|
73 |
)
|
|
|
|
|
|
|
|
|
|
|
74 |
except Exception as e:
|
75 |
print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
|
76 |
|
77 |
+
def set_predictors(self):
|
78 |
+
if self.model is None:
|
79 |
+
self.load_model()
|
80 |
+
|
81 |
+
self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
|
82 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
83 |
+
model=self.model,
|
84 |
+
**self.maskgen_hparams
|
85 |
+
)
|
86 |
+
|
87 |
def generate_mask(self,
|
88 |
image: np.ndarray):
|
89 |
return self.mask_generator.generate(image)
|
|
|
109 |
output_file_name = f"result-{timestamp}.psd"
|
110 |
output_path = os.path.join(self.output_dir, "psd", output_file_name)
|
111 |
|
112 |
+
if self.model is None or self.model_type != model_type:
|
113 |
self.model_type = model_type
|
|
|
114 |
self.load_model()
|
115 |
|
116 |
+
if self.mask_generator is None or self.maskgen_hparams != maskgen_hparams:
|
117 |
+
self.maskgen_hparams = maskgen_hparams
|
118 |
+
self.set_predictors()
|
119 |
+
|
120 |
masks = self.mask_generator.generate(image)
|
121 |
|
122 |
save_psd_with_masks(image, masks, output_path)
|