jhj0517 commited on
Commit
8d52a7d
1 Parent(s): f822d17

Divide model / predictors

Browse files
Files changed (1) hide show
  1. modules/sam_inference.py +15 -7
modules/sam_inference.py CHANGED
@@ -71,14 +71,19 @@ class SamInference:
71
  ckpt_path=model_path,
72
  device=self.device
73
  )
74
- self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
75
- self.mask_generator = SAM2AutomaticMaskGenerator(
76
- model=self.model,
77
- **self.maskgen_hparams
78
- )
79
  except Exception as e:
80
  print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
81
 
 
 
 
 
 
 
 
 
 
 
82
  def generate_mask(self,
83
  image: np.ndarray):
84
  return self.mask_generator.generate(image)
@@ -104,11 +109,14 @@ class SamInference:
104
  output_file_name = f"result-{timestamp}.psd"
105
  output_path = os.path.join(self.output_dir, "psd", output_file_name)
106
 
107
- if self.model is None or self.mask_generator is None or self.model_type != model_type or self.maskgen_hparams != maskgen_hparams:
108
  self.model_type = model_type
109
- self.maskgen_hparams = maskgen_hparams
110
  self.load_model()
111
 
 
 
 
 
112
  masks = self.mask_generator.generate(image)
113
 
114
  save_psd_with_masks(image, masks, output_path)
 
71
  ckpt_path=model_path,
72
  device=self.device
73
  )
 
 
 
 
 
74
  except Exception as e:
75
  print(f"Layer Divider Extension : Error while Loading SAM2 model! {e}")
76
 
77
+ def set_predictors(self):
78
+ if self.model is None:
79
+ self.load_model()
80
+
81
+ self.image_predictor = SAM2ImagePredictor(sam_model=self.model)
82
+ self.mask_generator = SAM2AutomaticMaskGenerator(
83
+ model=self.model,
84
+ **self.maskgen_hparams
85
+ )
86
+
87
  def generate_mask(self,
88
  image: np.ndarray):
89
  return self.mask_generator.generate(image)
 
109
  output_file_name = f"result-{timestamp}.psd"
110
  output_path = os.path.join(self.output_dir, "psd", output_file_name)
111
 
112
+ if self.model is None or self.model_type != model_type:
113
  self.model_type = model_type
 
114
  self.load_model()
115
 
116
+ if self.mask_generator is None or self.maskgen_hparams != maskgen_hparams:
117
+ self.maskgen_hparams = maskgen_hparams
118
+ self.set_predictors()
119
+
120
  masks = self.mask_generator.generate(image)
121
 
122
  save_psd_with_masks(image, masks, output_path)