alessandro trinca tornidor commited on
Commit
d1e187e
·
1 Parent(s): f2e3747

bug: fix wrong import for samgis inference function

Browse files
samgis_lisa_on_zero/prediction_api/predictors.py CHANGED
@@ -6,7 +6,7 @@ from samgis_lisa_on_zero.io.tms2geotiff import download_extent
6
  from samgis_lisa_on_zero.io.wrappers_helpers import check_source_type_is_terrain
7
  from samgis_lisa_on_zero.prediction_api.global_models import models_dict, embedding_dict
8
  from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
9
- from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX, get_raster_inference_with_embedding_from_dict
10
  from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
11
  from samgis_core.utilities.type_hints import LlistFloat, DictStrInt, ListDict
12
 
@@ -40,7 +40,7 @@ def samexporter_predict(
40
  """
41
  if models_dict[model_name_key]["instance"] is None:
42
  app_logger.info(f"missing instance model {model_name_key}, instantiating it now!")
43
- model_instance = SegmentAnythingONNX(
44
  encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
45
  decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
46
  )
@@ -63,7 +63,7 @@ def samexporter_predict(
63
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
64
  app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
65
  embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
66
- mask, n_predictions = get_raster_inference_with_embedding_from_dict(
67
  img, prompt, models_instance, model_name_key, embedding_key, embedding_dict)
68
  app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
69
  return {
 
6
  from samgis_lisa_on_zero.io.wrappers_helpers import check_source_type_is_terrain
7
  from samgis_lisa_on_zero.prediction_api.global_models import models_dict, embedding_dict
8
  from samgis_lisa_on_zero.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
9
+ from samgis_core.prediction_api import sam_onnx2, sam_onnx_inference
10
  from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
11
  from samgis_core.utilities.type_hints import LlistFloat, DictStrInt, ListDict
12
 
 
40
  """
41
  if models_dict[model_name_key]["instance"] is None:
42
  app_logger.info(f"missing instance model {model_name_key}, instantiating it now!")
43
+ model_instance = sam_onnx2.SegmentAnythingONNX2(
44
  encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
45
  decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
46
  )
 
63
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
64
  app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
65
  embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
66
+ mask, n_predictions = sam_onnx_inference.get_raster_inference_with_embedding_from_dict(
67
  img, prompt, models_instance, model_name_key, embedding_key, embedding_dict)
68
  app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
69
  return {