[fix] handle source types with additional parameters other than x,y,z; set default source_type = OpenStreetMap.Mapnik
Browse files- src/app.py +1 -1
- src/io/lambda_helpers.py +7 -7
- src/io/tms2geotiff.py +10 -13
- src/prediction_api/predictors.py +4 -4
- src/utilities/constants.py +2 -1
- src/utilities/type_hints.py +1 -2
src/app.py
CHANGED
@@ -39,7 +39,7 @@ def lambda_handler(event: Dict, context: LambdaContext) -> str:
|
|
39 |
|
40 |
try:
|
41 |
body_response = samexporter_predict(
|
42 |
-
body_request["bbox"], body_request["prompt"], body_request["zoom"],
|
43 |
)
|
44 |
app_logger.info(f"output body_response length:{len(body_response)}.")
|
45 |
app_logger.debug(f"output body_response:{body_response}.")
|
|
|
39 |
|
40 |
try:
|
41 |
body_response = samexporter_predict(
|
42 |
+
body_request["bbox"], body_request["prompt"], body_request["zoom"], source=body_request["source"]
|
43 |
)
|
44 |
app_logger.info(f"output body_response length:{len(body_response)}.")
|
45 |
app_logger.debug(f"output body_response:{body_response}.")
|
src/io/lambda_helpers.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
"""lambda helper functions"""
|
2 |
from typing import Dict
|
3 |
-
from aws_lambda_powertools.event_handler import content_types
|
4 |
from xyzservices import providers
|
|
|
5 |
|
6 |
from src import app_logger
|
7 |
from src.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
|
@@ -52,6 +52,7 @@ def get_parsed_bbox_points(request_input: ApiRequestBody) -> Dict:
|
|
52 |
Returns:
|
53 |
dict with bounding box, prompt and zoom
|
54 |
"""
|
|
|
55 |
app_logger.info(f"try to parsing input request {request_input}...")
|
56 |
|
57 |
bbox = request_input.bbox
|
@@ -73,7 +74,7 @@ def get_parsed_bbox_points(request_input: ApiRequestBody) -> Dict:
|
|
73 |
"bbox": [ne_latlng, sw_latlng],
|
74 |
"prompt": new_prompt_list,
|
75 |
"zoom": new_zoom,
|
76 |
-
"
|
77 |
}
|
78 |
|
79 |
|
@@ -153,9 +154,8 @@ def get_parsed_request_body(event: Dict or str) -> ApiRequestBody:
|
|
153 |
|
154 |
|
155 |
def get_url_tile(source_type: str):
|
156 |
-
from src.utilities.constants import DEFAULT_TMS_NAME,
|
157 |
|
158 |
-
if source_type.lower() ==
|
159 |
-
return
|
160 |
-
|
161 |
-
return providers_type.url
|
|
|
1 |
"""lambda helper functions"""
|
2 |
from typing import Dict
|
|
|
3 |
from xyzservices import providers
|
4 |
+
from aws_lambda_powertools.event_handler import content_types
|
5 |
|
6 |
from src import app_logger
|
7 |
from src.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
|
|
|
52 |
Returns:
|
53 |
dict with bounding box, prompt and zoom
|
54 |
"""
|
55 |
+
|
56 |
app_logger.info(f"try to parsing input request {request_input}...")
|
57 |
|
58 |
bbox = request_input.bbox
|
|
|
74 |
"bbox": [ne_latlng, sw_latlng],
|
75 |
"prompt": new_prompt_list,
|
76 |
"zoom": new_zoom,
|
77 |
+
"source": get_url_tile(request_input.source_type)
|
78 |
}
|
79 |
|
80 |
|
|
|
154 |
|
155 |
|
156 |
def get_url_tile(source_type: str):
|
157 |
+
from src.utilities.constants import DEFAULT_TMS_NAME, DEFAULT_TMS_NAME_SHORT
|
158 |
|
159 |
+
if source_type.lower() == DEFAULT_TMS_NAME_SHORT:
|
160 |
+
return providers.query_name(DEFAULT_TMS_NAME)
|
161 |
+
return providers.query_name(source_type)
|
|
src/io/tms2geotiff.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import os
|
2 |
from numpy import ndarray
|
|
|
3 |
|
4 |
from src import app_logger
|
5 |
from src.utilities.constants import (OUTPUT_CRS_STRING, DRIVER_RASTERIO_GTIFF, N_MAX_RETRIES, N_CONNECTION, N_WAIT,
|
6 |
ZOOM_AUTO, BOOL_USE_CACHE)
|
7 |
from src.utilities.type_hints import tuple_ndarray_transform, tuple_float
|
8 |
|
9 |
-
|
10 |
bool_use_cache = int(os.getenv("BOOL_USE_CACHE", BOOL_USE_CACHE))
|
11 |
n_connection = int(os.getenv("N_CONNECTION", N_CONNECTION))
|
12 |
n_max_retries = int(os.getenv("N_MAX_RETRIES", N_MAX_RETRIES))
|
@@ -14,7 +14,8 @@ n_wait = int(os.getenv("N_WAIT", N_WAIT))
|
|
14 |
zoom_auto_string = os.getenv("ZOOM_AUTO", ZOOM_AUTO)
|
15 |
|
16 |
|
17 |
-
def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = zoom_auto_string,
|
|
|
18 |
wait: int = n_wait, max_retries: int = n_max_retries, n_connections: int = n_connection,
|
19 |
use_cache: bool = bool_use_cache) -> tuple_ndarray_transform:
|
20 |
"""
|
@@ -26,15 +27,11 @@ def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = z
|
|
26 |
e: East edge
|
27 |
n: North edge
|
28 |
zoom: Level of detail
|
29 |
-
source:
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
`{z}`, respectively. For local file paths, the file is read with
|
35 |
-
`rasterio` and all bands are loaded into the basemap.
|
36 |
-
IMPORTANT: tiles are assumed to be in the Spherical Mercator
|
37 |
-
projection (EPSG:3857), unless the `crs` keyword is specified.
|
38 |
wait: if the tile API is rate-limited, the number of seconds to wait
|
39 |
between a failed request and the next try
|
40 |
max_retries: total number of rejected requests allowed before contextily will stop trying to fetch more tiles
|
@@ -59,8 +56,8 @@ def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = z
|
|
59 |
app_logger.debug(f"download raster from source:{source} with bounding box w:{w}, s:{s}, e:{e}, n:{n}.")
|
60 |
app_logger.debug(f"types w:{type(w)}, s:{type(s)}, e:{type(e)}, n:{type(n)}.")
|
61 |
downloaded_raster, bbox_raster = contextily_tile.bounds2img(
|
62 |
-
w, s, e, n, zoom=zoom, source=source, ll=True, wait=wait, max_retries=max_retries,
|
63 |
-
use_cache=use_cache)
|
64 |
xp0, yp0 = _from4326_to3857(n, e)
|
65 |
xp1, yp1 = _from4326_to3857(s, w)
|
66 |
cropped_image_ndarray, cropped_transform = crop_raster(yp1, xp1, yp0, xp0, downloaded_raster, bbox_raster)
|
|
|
1 |
import os
|
2 |
from numpy import ndarray
|
3 |
+
from xyzservices import TileProvider
|
4 |
|
5 |
from src import app_logger
|
6 |
from src.utilities.constants import (OUTPUT_CRS_STRING, DRIVER_RASTERIO_GTIFF, N_MAX_RETRIES, N_CONNECTION, N_WAIT,
|
7 |
ZOOM_AUTO, BOOL_USE_CACHE)
|
8 |
from src.utilities.type_hints import tuple_ndarray_transform, tuple_float
|
9 |
|
|
|
10 |
bool_use_cache = int(os.getenv("BOOL_USE_CACHE", BOOL_USE_CACHE))
|
11 |
n_connection = int(os.getenv("N_CONNECTION", N_CONNECTION))
|
12 |
n_max_retries = int(os.getenv("N_MAX_RETRIES", N_MAX_RETRIES))
|
|
|
14 |
zoom_auto_string = os.getenv("ZOOM_AUTO", ZOOM_AUTO)
|
15 |
|
16 |
|
17 |
+
def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = zoom_auto_string,
|
18 |
+
source: TileProvider or str = None,
|
19 |
wait: int = n_wait, max_retries: int = n_max_retries, n_connections: int = n_connection,
|
20 |
use_cache: bool = bool_use_cache) -> tuple_ndarray_transform:
|
21 |
"""
|
|
|
27 |
e: East edge
|
28 |
n: North edge
|
29 |
zoom: Level of detail
|
30 |
+
source: The tile source: web tile provider or path to local file. The web tile provider can be in the form of
|
31 |
+
a :class:`xyzservices.TileProvider` object or a URL. The placeholders for the XYZ in the URL need to be
|
32 |
+
`{x}`, `{y}`, `{z}`, respectively. For local file paths, the file is read with `rasterio` and all bands are
|
33 |
+
loaded into the basemap. IMPORTANT: tiles are assumed to be in the Spherical Mercator projection
|
34 |
+
(EPSG:3857), unless the `crs` keyword is specified.
|
|
|
|
|
|
|
|
|
35 |
wait: if the tile API is rate-limited, the number of seconds to wait
|
36 |
between a failed request and the next try
|
37 |
max_retries: total number of rejected requests allowed before contextily will stop trying to fetch more tiles
|
|
|
56 |
app_logger.debug(f"download raster from source:{source} with bounding box w:{w}, s:{s}, e:{e}, n:{n}.")
|
57 |
app_logger.debug(f"types w:{type(w)}, s:{type(s)}, e:{type(e)}, n:{type(n)}.")
|
58 |
downloaded_raster, bbox_raster = contextily_tile.bounds2img(
|
59 |
+
w, s, e, n, zoom=zoom, source=source, ll=True, wait=wait, max_retries=max_retries,
|
60 |
+
n_connections=n_connections, use_cache=use_cache)
|
61 |
xp0, yp0 = _from4326_to3857(n, e)
|
62 |
xp1, yp1 = _from4326_to3857(s, w)
|
63 |
cropped_image_ndarray, cropped_transform = crop_raster(yp1, xp1, yp0, xp0, downloaded_raster, bbox_raster)
|
src/prediction_api/predictors.py
CHANGED
@@ -16,7 +16,7 @@ def samexporter_predict(
|
|
16 |
prompt: list_dict,
|
17 |
zoom: float,
|
18 |
model_name: str = "fastsam",
|
19 |
-
|
20 |
) -> dict_str_int:
|
21 |
"""
|
22 |
Return predictions as a geojson from a geo-referenced image using the given input prompt.
|
@@ -31,7 +31,7 @@ def samexporter_predict(
|
|
31 |
prompt: machine learning input prompt
|
32 |
zoom: Level of detail
|
33 |
model_name: machine learning model name
|
34 |
-
|
35 |
|
36 |
Returns:
|
37 |
Affine transform
|
@@ -47,8 +47,8 @@ def samexporter_predict(
|
|
47 |
models_instance = models_dict[model_name]["instance"]
|
48 |
|
49 |
pt0, pt1 = bbox
|
50 |
-
app_logger.info(f"tile_source: {
|
51 |
-
img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=
|
52 |
app_logger.info(
|
53 |
f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
|
54 |
|
|
|
16 |
prompt: list_dict,
|
17 |
zoom: float,
|
18 |
model_name: str = "fastsam",
|
19 |
+
source: str = DEFAULT_TMS
|
20 |
) -> dict_str_int:
|
21 |
"""
|
22 |
Return predictions as a geojson from a geo-referenced image using the given input prompt.
|
|
|
31 |
prompt: machine learning input prompt
|
32 |
zoom: Level of detail
|
33 |
model_name: machine learning model name
|
34 |
+
source: xyz
|
35 |
|
36 |
Returns:
|
37 |
Affine transform
|
|
|
47 |
models_instance = models_dict[model_name]["instance"]
|
48 |
|
49 |
pt0, pt1 = bbox
|
50 |
+
app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
|
51 |
+
img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
|
52 |
app_logger.info(
|
53 |
f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
|
54 |
|
src/utilities/constants.py
CHANGED
@@ -13,7 +13,8 @@ MODEL_ENCODER_NAME = "mobile_sam.encoder.onnx"
|
|
13 |
MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
|
14 |
TILE_SIZE = 256
|
15 |
EARTH_EQUATORIAL_RADIUS = 6378137.0
|
16 |
-
|
|
|
17 |
DEFAULT_TMS = 'https://tile.openstreetmap.org/{z}/{x}/{y}.png'
|
18 |
WKT_3857 = 'PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,'
|
19 |
WKT_3857 += 'AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],'
|
|
|
13 |
MODEL_DECODER_NAME = "sam_vit_h_4b8939.decoder.onnx"
|
14 |
TILE_SIZE = 256
|
15 |
EARTH_EQUATORIAL_RADIUS = 6378137.0
|
16 |
+
DEFAULT_TMS_NAME_SHORT = "openstreetmap"
|
17 |
+
DEFAULT_TMS_NAME = "OpenStreetMap.Mapnik"
|
18 |
DEFAULT_TMS = 'https://tile.openstreetmap.org/{z}/{x}/{y}.png'
|
19 |
WKT_3857 = 'PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,'
|
20 |
WKT_3857 += 'AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],'
|
src/utilities/type_hints.py
CHANGED
@@ -81,9 +81,8 @@ class ApiRequestBody(BaseModel):
|
|
81 |
bbox: RawBBox
|
82 |
prompt: list[RawPromptPoint | RawPromptRectangle]
|
83 |
zoom: int | float
|
84 |
-
source_type: str = "
|
85 |
debug: bool = False
|
86 |
-
url_tile: str = DEFAULT_TMS
|
87 |
|
88 |
|
89 |
class ApiResponseBodyFailure(BaseModel):
|
|
|
81 |
bbox: RawBBox
|
82 |
prompt: list[RawPromptPoint | RawPromptRectangle]
|
83 |
zoom: int | float
|
84 |
+
source_type: str = "OpenStreetMap.Mapnik"
|
85 |
debug: bool = False
|
|
|
86 |
|
87 |
|
88 |
class ApiResponseBodyFailure(BaseModel):
|