Spaces:
Sleeping
Sleeping
Update src/ai_processor.py
Browse files- src/ai_processor.py +4 -55
src/ai_processor.py
CHANGED
|
@@ -210,61 +210,10 @@ def load_yolo_model():
|
|
| 210 |
model = YOLO(YOLO_MODEL_PATH)
|
| 211 |
return model
|
| 212 |
|
| 213 |
-
def load_segmentation_model():
|
| 214 |
-
"""
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
- Tries plain load
|
| 218 |
-
- Falls back to patched InputLayer (drops batch_shape args)
|
| 219 |
-
- Tries SavedModel/.keras fallback
|
| 220 |
-
"""
|
| 221 |
-
import tensorflow as tf
|
| 222 |
-
tf.config.set_visible_devices([], "GPU")
|
| 223 |
-
|
| 224 |
-
def _plain_load():
|
| 225 |
-
return tf.keras.models.load_model(SEG_MODEL_PATH, compile=False)
|
| 226 |
-
|
| 227 |
-
def _patched_load():
|
| 228 |
-
from tensorflow.keras.layers import InputLayer as KInputLayer
|
| 229 |
-
|
| 230 |
-
def _InputLayerPatched(*args, **kwargs):
|
| 231 |
-
kwargs.pop("batch_shape", None)
|
| 232 |
-
kwargs.pop("batch_input_shape", None)
|
| 233 |
-
return KInputLayer(**kwargs)
|
| 234 |
-
|
| 235 |
-
return tf.keras.models.load_model(
|
| 236 |
-
SEG_MODEL_PATH,
|
| 237 |
-
compile=False,
|
| 238 |
-
custom_objects={"InputLayer": _InputLayerPatched},
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
# 1) Try plain
|
| 242 |
-
try:
|
| 243 |
-
m = _plain_load()
|
| 244 |
-
logging.info("✅ Segmentation model loaded (plain).")
|
| 245 |
-
return m
|
| 246 |
-
except Exception as e1:
|
| 247 |
-
logging.warning(f"Plain load failed: {e1}")
|
| 248 |
-
|
| 249 |
-
# 2) Try patched InputLayer
|
| 250 |
-
try:
|
| 251 |
-
m = _patched_load()
|
| 252 |
-
logging.info("✅ Segmentation model loaded (patched InputLayer).")
|
| 253 |
-
return m
|
| 254 |
-
except Exception as e2:
|
| 255 |
-
logging.warning(f"Patched load failed: {e2}")
|
| 256 |
-
|
| 257 |
-
# 3) Try SavedModel/.keras folder fallback (strip extension)
|
| 258 |
-
try:
|
| 259 |
-
base, _ = os.path.splitext(SEG_MODEL_PATH)
|
| 260 |
-
m = tf.keras.models.load_model(base, compile=False)
|
| 261 |
-
logging.info("✅ Segmentation model loaded (SavedModel/.keras fallback).")
|
| 262 |
-
return m
|
| 263 |
-
except Exception as e3:
|
| 264 |
-
raise RuntimeError(f"Segmentation model could not be loaded: {e3}")
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
|
| 269 |
def load_classification_pipeline():
|
| 270 |
pipe = _import_hf_cls()
|
|
|
|
| 210 |
model = YOLO(YOLO_MODEL_PATH)
|
| 211 |
return model
|
| 212 |
|
| 213 |
+
def load_segmentation_model(seg_model_path):
|
| 214 |
+
"""Lazy import and load segmentation model."""
|
| 215 |
+
from tensorflow.keras.models import load_model
|
| 216 |
+
return load_model(seg_model_path, compile=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def load_classification_pipeline():
|
| 219 |
pipe = _import_hf_cls()
|