aletrn commited on
Commit
54cbf5f
·
1 Parent(s): aa5bd92

[fix] handle source types with additional parameters other than x,y,z; set default source_type = OpenStreetMap.Mapnik

Browse files
src/app.py CHANGED
@@ -39,7 +39,7 @@ def lambda_handler(event: Dict, context: LambdaContext) -> str:
39
 
40
  try:
41
  body_response = samexporter_predict(
42
- body_request["bbox"], body_request["prompt"], body_request["zoom"], url_tile=body_request["url_tile"]
43
  )
44
  app_logger.info(f"output body_response length:{len(body_response)}.")
45
  app_logger.debug(f"output body_response:{body_response}.")
 
39
 
40
  try:
41
  body_response = samexporter_predict(
42
+ body_request["bbox"], body_request["prompt"], body_request["zoom"], source=body_request["source"]
43
  )
44
  app_logger.info(f"output body_response length:{len(body_response)}.")
45
  app_logger.debug(f"output body_response:{body_response}.")
src/io/lambda_helpers.py CHANGED
@@ -1,7 +1,7 @@
1
  """lambda helper functions"""
2
  from typing import Dict
3
- from aws_lambda_powertools.event_handler import content_types
4
  from xyzservices import providers
 
5
 
6
  from src import app_logger
7
  from src.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
@@ -52,6 +52,7 @@ def get_parsed_bbox_points(request_input: ApiRequestBody) -> Dict:
52
  Returns:
53
  dict with bounding box, prompt and zoom
54
  """
 
55
  app_logger.info(f"try to parsing input request {request_input}...")
56
 
57
  bbox = request_input.bbox
@@ -73,7 +74,7 @@ def get_parsed_bbox_points(request_input: ApiRequestBody) -> Dict:
73
  "bbox": [ne_latlng, sw_latlng],
74
  "prompt": new_prompt_list,
75
  "zoom": new_zoom,
76
- "url_tile": get_url_tile(request_input.source_type)
77
  }
78
 
79
 
@@ -153,9 +154,8 @@ def get_parsed_request_body(event: Dict or str) -> ApiRequestBody:
153
 
154
 
155
  def get_url_tile(source_type: str):
156
- from src.utilities.constants import DEFAULT_TMS_NAME, DEFAULT_TMS
157
 
158
- if source_type.lower() == DEFAULT_TMS_NAME:
159
- return DEFAULT_TMS
160
- providers_type = providers.query_name(source_type)
161
- return providers_type.url
 
1
  """lambda helper functions"""
2
  from typing import Dict
 
3
  from xyzservices import providers
4
+ from aws_lambda_powertools.event_handler import content_types
5
 
6
  from src import app_logger
7
  from src.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
 
52
  Returns:
53
  dict with bounding box, prompt and zoom
54
  """
55
+
56
  app_logger.info(f"try to parsing input request {request_input}...")
57
 
58
  bbox = request_input.bbox
 
74
  "bbox": [ne_latlng, sw_latlng],
75
  "prompt": new_prompt_list,
76
  "zoom": new_zoom,
77
+ "source": get_url_tile(request_input.source_type)
78
  }
79
 
80
 
 
154
 
155
 
156
  def get_url_tile(source_type: str):
157
+ from src.utilities.constants import DEFAULT_TMS_NAME, DEFAULT_TMS_NAME_SHORT
158
 
159
+ if source_type.lower() == DEFAULT_TMS_NAME_SHORT:
160
+ return providers.query_name(DEFAULT_TMS_NAME)
161
+ return providers.query_name(source_type)
 
src/io/tms2geotiff.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
2
  from numpy import ndarray
 
3
 
4
  from src import app_logger
5
  from src.utilities.constants import (OUTPUT_CRS_STRING, DRIVER_RASTERIO_GTIFF, N_MAX_RETRIES, N_CONNECTION, N_WAIT,
6
  ZOOM_AUTO, BOOL_USE_CACHE)
7
  from src.utilities.type_hints import tuple_ndarray_transform, tuple_float
8
 
9
-
10
  bool_use_cache = int(os.getenv("BOOL_USE_CACHE", BOOL_USE_CACHE))
11
  n_connection = int(os.getenv("N_CONNECTION", N_CONNECTION))
12
  n_max_retries = int(os.getenv("N_MAX_RETRIES", N_MAX_RETRIES))
@@ -14,7 +14,8 @@ n_wait = int(os.getenv("N_WAIT", N_WAIT))
14
  zoom_auto_string = os.getenv("ZOOM_AUTO", ZOOM_AUTO)
15
 
16
 
17
- def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = zoom_auto_string, source: str = None,
 
18
  wait: int = n_wait, max_retries: int = n_max_retries, n_connections: int = n_connection,
19
  use_cache: bool = bool_use_cache) -> tuple_ndarray_transform:
20
  """
@@ -26,15 +27,11 @@ def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = z
26
  e: East edge
27
  n: North edge
28
  zoom: Level of detail
29
- source: xyzservices.TileProvider object or str
30
- [Optional. Default: OpenStreetMap Humanitarian web tiles]
31
- The tile source: web tile provider or path to local file. The web tile
32
- provider can be in the form of a :class:`xyzservices.TileProvider` object or a
33
- URL. The placeholders for the XYZ in the URL need to be `{x}`, `{y}`,
34
- `{z}`, respectively. For local file paths, the file is read with
35
- `rasterio` and all bands are loaded into the basemap.
36
- IMPORTANT: tiles are assumed to be in the Spherical Mercator
37
- projection (EPSG:3857), unless the `crs` keyword is specified.
38
  wait: if the tile API is rate-limited, the number of seconds to wait
39
  between a failed request and the next try
40
  max_retries: total number of rejected requests allowed before contextily will stop trying to fetch more tiles
@@ -59,8 +56,8 @@ def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = z
59
  app_logger.debug(f"download raster from source:{source} with bounding box w:{w}, s:{s}, e:{e}, n:{n}.")
60
  app_logger.debug(f"types w:{type(w)}, s:{type(s)}, e:{type(e)}, n:{type(n)}.")
61
  downloaded_raster, bbox_raster = contextily_tile.bounds2img(
62
- w, s, e, n, zoom=zoom, source=source, ll=True, wait=wait, max_retries=max_retries, n_connections=n_connections,
63
- use_cache=use_cache)
64
  xp0, yp0 = _from4326_to3857(n, e)
65
  xp1, yp1 = _from4326_to3857(s, w)
66
  cropped_image_ndarray, cropped_transform = crop_raster(yp1, xp1, yp0, xp0, downloaded_raster, bbox_raster)
 
1
  import os
2
  from numpy import ndarray
3
+ from xyzservices import TileProvider
4
 
5
  from src import app_logger
6
  from src.utilities.constants import (OUTPUT_CRS_STRING, DRIVER_RASTERIO_GTIFF, N_MAX_RETRIES, N_CONNECTION, N_WAIT,
7
  ZOOM_AUTO, BOOL_USE_CACHE)
8
  from src.utilities.type_hints import tuple_ndarray_transform, tuple_float
9
 
 
10
  bool_use_cache = int(os.getenv("BOOL_USE_CACHE", BOOL_USE_CACHE))
11
  n_connection = int(os.getenv("N_CONNECTION", N_CONNECTION))
12
  n_max_retries = int(os.getenv("N_MAX_RETRIES", N_MAX_RETRIES))
 
14
  zoom_auto_string = os.getenv("ZOOM_AUTO", ZOOM_AUTO)
15
 
16
 
17
+ def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = zoom_auto_string,
18
+ source: TileProvider or str = None,
19
  wait: int = n_wait, max_retries: int = n_max_retries, n_connections: int = n_connection,
20
  use_cache: bool = bool_use_cache) -> tuple_ndarray_transform:
21
  """
 
27
  e: East edge
28
  n: North edge
29
  zoom: Level of detail
30
+ source: The tile source: web tile provider or path to local file. The web tile provider can be in the form of
31
+ a :class:`xyzservices.TileProvider` object or a URL. The placeholders for the XYZ in the URL need to be
32
+ `{x}`, `{y}`, `{z}`, respectively. For local file paths, the file is read with `rasterio` and all bands are
33
+ loaded into the basemap. IMPORTANT: tiles are assumed to be in the Spherical Mercator projection
34
+ (EPSG:3857), unless the `crs` keyword is specified.
 
 
 
 
35
  wait: if the tile API is rate-limited, the number of seconds to wait
36
  between a failed request and the next try
37
  max_retries: total number of rejected requests allowed before contextily will stop trying to fetch more tiles
 
56
  app_logger.debug(f"download raster from source:{source} with bounding box w:{w}, s:{s}, e:{e}, n:{n}.")
57
  app_logger.debug(f"types w:{type(w)}, s:{type(s)}, e:{type(e)}, n:{type(n)}.")
58
  downloaded_raster, bbox_raster = contextily_tile.bounds2img(
59
+ w, s, e, n, zoom=zoom, source=source, ll=True, wait=wait, max_retries=max_retries,
60
+ n_connections=n_connections, use_cache=use_cache)
61
  xp0, yp0 = _from4326_to3857(n, e)
62
  xp1, yp1 = _from4326_to3857(s, w)
63
  cropped_image_ndarray, cropped_transform = crop_raster(yp1, xp1, yp0, xp0, downloaded_raster, bbox_raster)
src/prediction_api/predictors.py CHANGED
@@ -16,7 +16,7 @@ def samexporter_predict(
16
  prompt: list_dict,
17
  zoom: float,
18
  model_name: str = "fastsam",
19
- url_tile: str = DEFAULT_TMS
20
  ) -> dict_str_int:
21
  """
22
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
@@ -31,7 +31,7 @@ def samexporter_predict(
31
  prompt: machine learning input prompt
32
  zoom: Level of detail
33
  model_name: machine learning model name
34
- url_tile: server url tile
35
 
36
  Returns:
37
  Affine transform
@@ -47,8 +47,8 @@ def samexporter_predict(
47
  models_instance = models_dict[model_name]["instance"]
48
 
49
  pt0, pt1 = bbox
50
- app_logger.info(f"tile_source: {url_tile}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
51
- img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=url_tile)
52
  app_logger.info(
53
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
54
 
 
16
  prompt: list_dict,
17
  zoom: float,
18
  model_name: str = "fastsam",
19
+ source: str = DEFAULT_TMS
20
  ) -> dict_str_int:
21
  """
22
  Return predictions as a geojson from a geo-referenced image using the given input prompt.
 
31
  prompt: machine learning input prompt
32
  zoom: Level of detail
33
  model_name: machine learning model name
34
+ source: xyz
35
 
36
  Returns:
37
  Affine transform
 
47
  models_instance = models_dict[model_name]["instance"]
48
 
49
  pt0, pt1 = bbox
50
+ app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
51
+ img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
52
  app_logger.info(
53
  f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
54
 
src/utilities/constants.py CHANGED
@@ -13,7 +13,8 @@ MODEL_ENCODER_NAME = "mobile_sam.encoder.onnx"
13
  MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
14
  TILE_SIZE = 256
15
  EARTH_EQUATORIAL_RADIUS = 6378137.0
16
- DEFAULT_TMS_NAME = "openstreetmap"
 
17
  DEFAULT_TMS = 'https://tile.openstreetmap.org/{z}/{x}/{y}.png'
18
  WKT_3857 = 'PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,'
19
  WKT_3857 += 'AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],'
 
13
  MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
14
  TILE_SIZE = 256
15
  EARTH_EQUATORIAL_RADIUS = 6378137.0
16
+ DEFAULT_TMS_NAME_SHORT = "openstreetmap"
17
+ DEFAULT_TMS_NAME = "OpenStreetMap.Mapnik"
18
  DEFAULT_TMS = 'https://tile.openstreetmap.org/{z}/{x}/{y}.png'
19
  WKT_3857 = 'PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,'
20
  WKT_3857 += 'AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],'
src/utilities/type_hints.py CHANGED
@@ -81,9 +81,8 @@ class ApiRequestBody(BaseModel):
81
  bbox: RawBBox
82
  prompt: list[RawPromptPoint | RawPromptRectangle]
83
  zoom: int | float
84
- source_type: str = "Satellite"
85
  debug: bool = False
86
- url_tile: str = DEFAULT_TMS
87
 
88
 
89
  class ApiResponseBodyFailure(BaseModel):
 
81
  bbox: RawBBox
82
  prompt: list[RawPromptPoint | RawPromptRectangle]
83
  zoom: int | float
84
+ source_type: str = "OpenStreetMap.Mapnik"
85
  debug: bool = False
 
86
 
87
 
88
  class ApiResponseBodyFailure(BaseModel):