SmartHeal commited on
Commit
e4e240e
·
verified ·
1 Parent(s): 2857ee5

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +4 -55
src/ai_processor.py CHANGED
@@ -210,61 +210,10 @@ def load_yolo_model():
210
  model = YOLO(YOLO_MODEL_PATH)
211
  return model
212
 
213
- def load_segmentation_model():
214
- """
215
- Robust TF/Keras loader that:
216
- - Hides GPUs
217
- - Tries plain load
218
- - Falls back to patched InputLayer (drops batch_shape args)
219
- - Tries SavedModel/.keras fallback
220
- """
221
- import tensorflow as tf
222
- tf.config.set_visible_devices([], "GPU")
223
-
224
- def _plain_load():
225
- return tf.keras.models.load_model(SEG_MODEL_PATH, compile=False)
226
-
227
- def _patched_load():
228
- from tensorflow.keras.layers import InputLayer as KInputLayer
229
-
230
- def _InputLayerPatched(*args, **kwargs):
231
- kwargs.pop("batch_shape", None)
232
- kwargs.pop("batch_input_shape", None)
233
- return KInputLayer(**kwargs)
234
-
235
- return tf.keras.models.load_model(
236
- SEG_MODEL_PATH,
237
- compile=False,
238
- custom_objects={"InputLayer": _InputLayerPatched},
239
- )
240
-
241
- # 1) Try plain
242
- try:
243
- m = _plain_load()
244
- logging.info("✅ Segmentation model loaded (plain).")
245
- return m
246
- except Exception as e1:
247
- logging.warning(f"Plain load failed: {e1}")
248
-
249
- # 2) Try patched InputLayer
250
- try:
251
- m = _patched_load()
252
- logging.info("✅ Segmentation model loaded (patched InputLayer).")
253
- return m
254
- except Exception as e2:
255
- logging.warning(f"Patched load failed: {e2}")
256
-
257
- # 3) Try SavedModel/.keras folder fallback (strip extension)
258
- try:
259
- base, _ = os.path.splitext(SEG_MODEL_PATH)
260
- m = tf.keras.models.load_model(base, compile=False)
261
- logging.info("✅ Segmentation model loaded (SavedModel/.keras fallback).")
262
- return m
263
- except Exception as e3:
264
- raise RuntimeError(f"Segmentation model could not be loaded: {e3}")
265
-
266
-
267
-
268
 
269
  def load_classification_pipeline():
270
  pipe = _import_hf_cls()
 
210
  model = YOLO(YOLO_MODEL_PATH)
211
  return model
212
 
213
+ def load_segmentation_model(seg_model_path):
214
+ """Lazy import and load segmentation model."""
215
+ from tensorflow.keras.models import load_model
216
+ return load_model(seg_model_path, compile=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  def load_classification_pipeline():
219
  pipe = _import_hf_cls()