alessandro trinca tornidor commited on
Commit
c311b69
1 Parent(s): c92d24c

[feat] add support for image embedding re-use

Browse files
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -23,8 +23,8 @@ python = "~3.10"
23
  python-dotenv = "^1.0.1"
24
  rasterio = "^1.3.9"
25
  requests = "^2.31.0"
26
- lisa-on-cuda = "^1.0.4"
27
- samgis-core = "^1.0.7"
28
 
29
  [tool.poetry.group.aws_lambda]
30
  optional = true
 
23
  python-dotenv = "^1.0.1"
24
  rasterio = "^1.3.9"
25
  requests = "^2.31.0"
26
+ samgis-core = "^1.1.2"
27
+ lisa-on-cuda = "^1.1.1"
28
 
29
  [tool.poetry.group.aws_lambda]
30
  optional = true
samgis_lisa_on_cuda/io/coordinates_pixel_conversion.py CHANGED
@@ -1,5 +1,5 @@
1
  """functions useful to convert to/from latitude-longitude coordinates to pixel image coordinates"""
2
- from samgis_core.utilities.type_hints import tuple_float, tuple_float_any
3
 
4
  from samgis_lisa_on_cuda import app_logger
5
  from samgis_lisa_on_cuda.utilities.constants import TILE_SIZE, EARTH_EQUATORIAL_RADIUS
@@ -82,7 +82,7 @@ def get_latlng_to_pixel_coordinates(
82
  return point
83
 
84
 
85
- def _from4326_to3857(lat: float, lon: float) -> tuple_float or tuple_float_any:
86
  from math import radians, log, tan
87
 
88
  x_tile: float = radians(lon) * EARTH_EQUATORIAL_RADIUS
 
1
  """functions useful to convert to/from latitude-longitude coordinates to pixel image coordinates"""
2
+ from samgis_core.utilities.type_hints import TupleFloat, TupleFloatAny
3
 
4
  from samgis_lisa_on_cuda import app_logger
5
  from samgis_lisa_on_cuda.utilities.constants import TILE_SIZE, EARTH_EQUATORIAL_RADIUS
 
82
  return point
83
 
84
 
85
+ def _from4326_to3857(lat: float, lon: float) -> TupleFloat or TupleFloatAny:
86
  from math import radians, log, tan
87
 
88
  x_tile: float = radians(lon) * EARTH_EQUATORIAL_RADIUS
samgis_lisa_on_cuda/io/geo_helpers.py CHANGED
@@ -2,11 +2,11 @@
2
  from affine import Affine
3
  from numpy import ndarray as np_ndarray
4
 
5
- from samgis_core.utilities.type_hints import list_float, tuple_float, dict_str_int
6
  from samgis_lisa_on_cuda import app_logger
7
 
8
 
9
- def load_affine_transformation_from_matrix(matrix_source_coefficients: list_float) -> Affine:
10
  """
11
  Wrapper for rasterio.Affine.from_gdal() method
12
 
@@ -32,7 +32,7 @@ def load_affine_transformation_from_matrix(matrix_source_coefficients: list_floa
32
  raise e
33
 
34
 
35
- def get_affine_transform_from_gdal(matrix_source_coefficients: list_float or tuple_float) -> Affine:
36
  """wrapper for rasterio Affine from_gdal method
37
 
38
  Args:
@@ -44,7 +44,7 @@ def get_affine_transform_from_gdal(matrix_source_coefficients: list_float or tup
44
  return Affine.from_gdal(*matrix_source_coefficients)
45
 
46
 
47
- def get_vectorized_raster_as_geojson(mask: np_ndarray, transform: tuple_float) -> dict_str_int:
48
  """
49
  Get shapes and values of connected regions in a dataset or array
50
 
 
2
  from affine import Affine
3
  from numpy import ndarray as np_ndarray
4
 
5
+ from samgis_core.utilities.type_hints import ListFloat, TupleFloat, DictStrInt
6
  from samgis_lisa_on_cuda import app_logger
7
 
8
 
9
+ def load_affine_transformation_from_matrix(matrix_source_coefficients: ListFloat) -> Affine:
10
  """
11
  Wrapper for rasterio.Affine.from_gdal() method
12
 
 
32
  raise e
33
 
34
 
35
+ def get_affine_transform_from_gdal(matrix_source_coefficients: ListFloat or TupleFloat) -> Affine:
36
  """wrapper for rasterio Affine from_gdal method
37
 
38
  Args:
 
44
  return Affine.from_gdal(*matrix_source_coefficients)
45
 
46
 
47
+ def get_vectorized_raster_as_geojson(mask: np_ndarray, transform: TupleFloat) -> DictStrInt:
48
  """
49
  Get shapes and values of connected regions in a dataset or array
50
 
samgis_lisa_on_cuda/io/tms2geotiff.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
 
3
  from numpy import ndarray
4
- from samgis_core.utilities.type_hints import tuple_float
5
  from xyzservices import TileProvider
6
 
7
  from samgis_lisa_on_cuda import app_logger
@@ -70,7 +70,7 @@ def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = z
70
  raise e_download_extent
71
 
72
 
73
- def crop_raster(w: float, s: float, e: float, n: float, raster: ndarray, raster_bbox: tuple_float,
74
  crs: str = OUTPUT_CRS_STRING, driver: str = DRIVER_RASTERIO_GTIFF) -> tuple_ndarray_transform:
75
  """
76
  Crop a raster using given bounding box (w, s, e, n) values
@@ -134,7 +134,7 @@ def crop_raster(w: float, s: float, e: float, n: float, raster: ndarray, raster_
134
  raise e_crop_raster
135
 
136
 
137
- def get_transform_raster(raster: ndarray, raster_bbox: tuple_float) -> tuple_ndarray_transform:
138
  """
139
  Convert the input raster image to RGB and extract the Affine
140
 
 
1
  import os
2
 
3
  from numpy import ndarray
4
+ from samgis_core.utilities.type_hints import TupleFloat
5
  from xyzservices import TileProvider
6
 
7
  from samgis_lisa_on_cuda import app_logger
 
70
  raise e_download_extent
71
 
72
 
73
+ def crop_raster(w: float, s: float, e: float, n: float, raster: ndarray, raster_bbox: TupleFloat,
74
  crs: str = OUTPUT_CRS_STRING, driver: str = DRIVER_RASTERIO_GTIFF) -> tuple_ndarray_transform:
75
  """
76
  Crop a raster using given bounding box (w, s, e, n) values
 
134
  raise e_crop_raster
135
 
136
 
137
+ def get_transform_raster(raster: ndarray, raster_bbox: TupleFloat) -> tuple_ndarray_transform:
138
  """
139
  Convert the input raster image to RGB and extract the Affine
140
 
samgis_lisa_on_cuda/io/wrappers_helpers.py CHANGED
@@ -238,3 +238,25 @@ def get_url_tile(source_type: str):
238
 
239
  def check_source_type_is_terrain(source: str | TileProvider):
240
  return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  def check_source_type_is_terrain(source: str | TileProvider):
240
  return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
241
+
242
+
243
+ def get_source_name(source: str | TileProvider) -> str | bool:
244
+ try:
245
+ match source.lower():
246
+ case XYZDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
247
+ source_output = providers.query_name(XYZDefaultProvidersNames.DEFAULT_TILES_NAME)
248
+ case _:
249
+ source_output = providers.query_name(source)
250
+ if isinstance(source_output, str):
251
+ return source_output
252
+ try:
253
+ source_dict = dict(source_output)
254
+ app_logger.info(f"source_dict:{type(source_dict)}, {'name' in source_dict}, source_dict:{source_dict}.")
255
+ return source_dict["name"]
256
+ except KeyError as ke:
257
+ app_logger.error(f"ke:{ke}.")
258
+ except ValueError as ve:
259
+ app_logger.info(f"source name::{source}, ve:{ve}.")
260
+ app_logger.info(f"source name::{source}.")
261
+
262
+ return False
samgis_lisa_on_cuda/prediction_api/global_models.py CHANGED
@@ -2,3 +2,5 @@ models_dict = {
2
  "fastsam": {"instance": None},
3
  "lisa": {"inference": None}
4
  }
 
 
 
2
  "fastsam": {"instance": None},
3
  "lisa": {"inference": None}
4
  }
5
+ embedding_dict = {}
6
+
samgis_lisa_on_cuda/prediction_api/lisa.py CHANGED
@@ -1,7 +1,7 @@
1
  from datetime import datetime
2
 
3
  from lisa_on_cuda.utils import app_helpers
4
- from samgis_core.utilities.type_hints import llist_float, dict_str_int
5
  from samgis_lisa_on_cuda import app_logger
6
  from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
7
  from samgis_lisa_on_cuda.io.raster_helpers import write_raster_png, write_raster_tiff
@@ -13,12 +13,13 @@ msg_write_tmp_on_disk = "found option to write images and geojson output..."
13
 
14
 
15
  def lisa_predict(
16
- bbox: llist_float,
17
  prompt: str,
18
  zoom: float,
19
  inference_function_name_key: str = "lisa",
20
- source: str = DEFAULT_URL_TILES
21
- ) -> dict_str_int:
 
22
  """
23
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
24
 
@@ -33,6 +34,7 @@ def lisa_predict(
33
  zoom: Level of detail
34
  inference_function_name_key: machine learning model name
35
  source: xyz
 
36
 
37
  Returns:
38
  Affine transform
@@ -54,9 +56,9 @@ def lisa_predict(
54
  app_logger.info(
55
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
56
  folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
 
57
  if bool(folder_write_tmp_on_disk):
58
  now = datetime.now().strftime('%Y%m%d_%H%M%S')
59
- prefix = f"w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}_"
60
  app_logger.info(msg_write_tmp_on_disk + f"with coords {prefix}, shape:{img.shape}, {len(img.shape)}.")
61
  if img.shape and len(img.shape) == 2:
62
  write_raster_tiff(img, transform, f"{prefix}_{now}_", f"raw_tiff", folder_write_tmp_on_disk)
@@ -65,7 +67,9 @@ def lisa_predict(
65
  else:
66
  app_logger.info("keep all temp data in memory...")
67
 
68
- _, mask, output_string = inference_fn(prompt, img)
 
 
69
  # app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
70
  return {
71
  "output_string": output_string,
 
1
  from datetime import datetime
2
 
3
  from lisa_on_cuda.utils import app_helpers
4
+ from samgis_core.utilities.type_hints import LlistFloat, DictStrInt
5
  from samgis_lisa_on_cuda import app_logger
6
  from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
7
  from samgis_lisa_on_cuda.io.raster_helpers import write_raster_png, write_raster_tiff
 
13
 
14
 
15
  def lisa_predict(
16
+ bbox: LlistFloat,
17
  prompt: str,
18
  zoom: float,
19
  inference_function_name_key: str = "lisa",
20
+ source: str = DEFAULT_URL_TILES,
21
+ source_name: str = None
22
+ ) -> DictStrInt:
23
  """
24
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
25
 
 
34
  zoom: Level of detail
35
  inference_function_name_key: machine learning model name
36
  source: xyz
37
+ source_name: name of tile provider
38
 
39
  Returns:
40
  Affine transform
 
56
  app_logger.info(
57
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
58
  folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
59
+ prefix = f"w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}_"
60
  if bool(folder_write_tmp_on_disk):
61
  now = datetime.now().strftime('%Y%m%d_%H%M%S')
 
62
  app_logger.info(msg_write_tmp_on_disk + f"with coords {prefix}, shape:{img.shape}, {len(img.shape)}.")
63
  if img.shape and len(img.shape) == 2:
64
  write_raster_tiff(img, transform, f"{prefix}_{now}_", f"raw_tiff", folder_write_tmp_on_disk)
 
67
  else:
68
  app_logger.info("keep all temp data in memory...")
69
 
70
+ app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
71
+ embedding_key = f"{source_name}_z{zoom}_{prefix}"
72
+ _, mask, output_string = inference_fn(prompt, img, app_logger, embedding_key)
73
  # app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
74
  return {
75
  "output_string": output_string,
samgis_lisa_on_cuda/prediction_api/predictors.py CHANGED
@@ -4,21 +4,21 @@ from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
4
  from samgis_lisa_on_cuda.io.raster_helpers import get_raster_terrain_rgb_like, get_rgb_prediction_image
5
  from samgis_lisa_on_cuda.io.tms2geotiff import download_extent
6
  from samgis_lisa_on_cuda.io.wrappers_helpers import check_source_type_is_terrain
7
- from samgis_lisa_on_cuda.prediction_api.global_models import models_dict
8
  from samgis_lisa_on_cuda.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
9
- from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX
10
- from samgis_core.prediction_api.sam_onnx import get_raster_inference
11
  from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
12
- from samgis_core.utilities.type_hints import llist_float, dict_str_int, list_dict
13
 
14
 
15
  def samexporter_predict(
16
- bbox: llist_float,
17
- prompt: list_dict,
18
  zoom: float,
19
  model_name_key: str = "fastsam",
20
- source: str = DEFAULT_URL_TILES
21
- ) -> dict_str_int:
 
22
  """
23
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
24
 
@@ -33,6 +33,7 @@ def samexporter_predict(
33
  zoom: Level of detail
34
  model_name_key: machine learning model name
35
  source: xyz
 
36
 
37
  Returns:
38
  Affine transform
@@ -60,8 +61,10 @@ def samexporter_predict(
60
 
61
  app_logger.info(
62
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
63
-
64
- mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name_key)
 
 
65
  app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
66
  return {
67
  "n_predictions": n_predictions,
 
4
  from samgis_lisa_on_cuda.io.raster_helpers import get_raster_terrain_rgb_like, get_rgb_prediction_image
5
  from samgis_lisa_on_cuda.io.tms2geotiff import download_extent
6
  from samgis_lisa_on_cuda.io.wrappers_helpers import check_source_type_is_terrain
7
+ from samgis_lisa_on_cuda.prediction_api.global_models import models_dict, embedding_dict
8
  from samgis_lisa_on_cuda.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
9
+ from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX, get_raster_inference_with_embedding_from_dict
 
10
  from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
11
+ from samgis_core.utilities.type_hints import LlistFloat, DictStrInt, ListDict
12
 
13
 
14
  def samexporter_predict(
15
+ bbox: LlistFloat,
16
+ prompt: ListDict,
17
  zoom: float,
18
  model_name_key: str = "fastsam",
19
+ source: str = DEFAULT_URL_TILES,
20
+ source_name: str = None
21
+ ) -> DictStrInt:
22
  """
23
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
24
 
 
33
  zoom: Level of detail
34
  model_name_key: machine learning model name
35
  source: xyz
36
+ source_name: name of tile provider
37
 
38
  Returns:
39
  Affine transform
 
61
 
62
  app_logger.info(
63
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
64
+ app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
65
+ embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
66
+ mask, n_predictions = get_raster_inference_with_embedding_from_dict(
67
+ img, prompt, models_instance, model_name_key, embedding_key, embedding_dict)
68
  app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
69
  return {
70
  "n_predictions": n_predictions,
wrappers/fastapi_wrapper.py CHANGED
@@ -88,7 +88,7 @@ def post_test_string(request_input: StringPromptApiRequestBody) -> JSONResponse:
88
  @app.post("/infer_lisa")
89
  def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
90
  from samgis_lisa_on_cuda.prediction_api import lisa
91
- from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
92
 
93
  app_logger.info("starting lisa inference request...")
94
 
@@ -100,9 +100,11 @@ def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
100
  app_logger.info(f"lisa body_request:{body_request}.")
101
  app_logger.info(f"lisa module:{lisa}.")
102
  try:
 
 
103
  output = lisa.lisa_predict(
104
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
105
- source=body_request["source"]
106
  )
107
  duration_run = time.time() - time_start_run
108
  app_logger.info(f"duration_run:{duration_run}.")
@@ -132,7 +134,7 @@ def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
132
  @app.post("/infer_samgis")
133
  def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
134
  from samgis_lisa_on_cuda.prediction_api import predictors
135
- from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt
136
 
137
  app_logger.info("starting plain samgis inference request...")
138
 
@@ -143,9 +145,11 @@ def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
143
  body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
144
  app_logger.info(f"body_request:{body_request}.")
145
  try:
 
 
146
  output = predictors.samexporter_predict(
147
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
148
- source=body_request["source"]
149
  )
150
  duration_run = time.time() - time_start_run
151
  app_logger.info(f"duration_run:{duration_run}.")
@@ -208,7 +212,7 @@ async def lisa() -> FileResponse:
208
  return FileResponse(path=WORKDIR / "static" / "dist" / "lisa.html", media_type="text/html")
209
 
210
 
211
- app.mount("/", StaticFiles(directory=WORKDIR / "static" / "dist", html=True), name="static")
212
 
213
 
214
  @app.get("/")
 
88
  @app.post("/infer_lisa")
89
  def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
90
  from samgis_lisa_on_cuda.prediction_api import lisa
91
+ from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt, get_source_name
92
 
93
  app_logger.info("starting lisa inference request...")
94
 
 
100
  app_logger.info(f"lisa body_request:{body_request}.")
101
  app_logger.info(f"lisa module:{lisa}.")
102
  try:
103
+ source_name = get_source_name(request_input.source_type)
104
+ app_logger.info(f"source_name = {source_name}.")
105
  output = lisa.lisa_predict(
106
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
107
+ source=body_request["source"], source_name=source_name
108
  )
109
  duration_run = time.time() - time_start_run
110
  app_logger.info(f"duration_run:{duration_run}.")
 
134
  @app.post("/infer_samgis")
135
  def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
136
  from samgis_lisa_on_cuda.prediction_api import predictors
137
+ from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt, get_source_name
138
 
139
  app_logger.info("starting plain samgis inference request...")
140
 
 
145
  body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
146
  app_logger.info(f"body_request:{body_request}.")
147
  try:
148
+ source_name = get_source_name(request_input.source_type)
149
+ app_logger.info(f"source_name = {source_name}.")
150
  output = predictors.samexporter_predict(
151
  bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
152
+ source=body_request["source"], source_name=source_name
153
  )
154
  duration_run = time.time() - time_start_run
155
  app_logger.info(f"duration_run:{duration_run}.")
 
212
  return FileResponse(path=WORKDIR / "static" / "dist" / "lisa.html", media_type="text/html")
213
 
214
 
215
+ app.mount("/", StaticFiles(directory=WORKDIR / "static" / "dist", html=True), name="root")
216
 
217
 
218
  @app.get("/")