alessandro trinca tornidor committed
Commit: 62c986b
Parent(s): 2dbe20c

refactor: add and remove all needed files because of the switch from docker to gradio sdk and samgis-lisa
- dockerfiles/dockerfile-lisa-predictions +0 -1
- requirements_no_versions.txt +2 -0
- samgis_lisa_on_cuda/__init__.py +0 -27
- samgis_lisa_on_cuda/__version__.py +0 -4
- samgis_lisa_on_cuda/io/__init__.py +0 -1
- samgis_lisa_on_cuda/io/coordinates_pixel_conversion.py +0 -99
- samgis_lisa_on_cuda/io/geo_helpers.py +0 -91
- samgis_lisa_on_cuda/io/raster_helpers.py +0 -331
- samgis_lisa_on_cuda/io/tms2geotiff.py +0 -181
- samgis_lisa_on_cuda/io/wrappers_helpers.py +0 -262
- samgis_lisa_on_cuda/prediction_api/__init__.py +0 -1
- samgis_lisa_on_cuda/prediction_api/global_models.py +0 -6
- samgis_lisa_on_cuda/prediction_api/lisa.py +0 -77
- samgis_lisa_on_cuda/prediction_api/predictors.py +0 -72
- samgis_lisa_on_cuda/utilities/__init__.py +0 -1
- samgis_lisa_on_cuda/utilities/constants.py +0 -39
- samgis_lisa_on_cuda/utilities/session_logger.py +0 -65
- samgis_lisa_on_cuda/utilities/type_hints.py +0 -111
- scripts/create_folders_and_variables_if_not_exists.py +0 -55
- scripts/extract-openapi-fastapi.py +2 -1
- scripts/extract-openapi-lambda.py +0 -13
- tests/__init__.py +2 -1
- tests/events/lambda_handler/10/550/391.png +0 -0
- tests/events/lambda_handler/10/550/392.png +0 -0
- tests/events/lambda_handler/10/550/393.png +0 -0
- tests/events/lambda_handler/10/551/391.png +0 -0
- tests/events/lambda_handler/10/551/392.png +0 -0
- tests/events/lambda_handler/10/551/393.png +0 -0
- tests/events/lambda_handler/10/552/391.png +0 -0
- tests/events/lambda_handler/10/552/392.png +0 -0
- tests/events/lambda_handler/10/552/393.png +0 -0
- tests/events/lambda_handler/10/553/391.png +0 -0
- tests/events/lambda_handler/10/553/392.png +0 -0
- tests/events/lambda_handler/10/553/393.png +0 -0
- tests/events/lambda_handler/10/554/391.png +0 -0
- tests/events/lambda_handler/10/554/392.png +0 -0
- tests/events/lambda_handler/10/554/393.png +0 -0
- tests/io/test_coordinates_pixel_conversion.py +0 -27
- tests/io/test_geo_helpers.py +0 -103
- tests/io/test_raster_helpers.py +0 -255
- tests/io/test_tms2geotiff.py +0 -138
- tests/io/test_wrappers_helpers.py +0 -135
- tests/local_tiles_http_server.py +0 -46
- tests/prediction_api/__init__.py +0 -0
- tests/prediction_api/test_predictors.py +0 -64
- tests/{test_fastapi_app.py → test_app.py} +19 -22
- tests/test_lambda_app.py +0 -232
- wrappers/__init__.py +0 -0
- wrappers/fastapi_wrapper.py +0 -273
- wrappers/lambda_wrapper.py +0 -58
dockerfiles/dockerfile-lisa-predictions
CHANGED
@@ -23,7 +23,6 @@ RUN ls -l ${LAMBDA_TASK_ROOT}
 RUN ls -ld ${LAMBDA_TASK_ROOT}
 RUN ls -l ${LAMBDA_TASK_ROOT}/machine_learning_models
 RUN python -c "import sys; print(sys.path)"
-RUN python -c "import cv2"
 RUN python -c "import fastapi"
 RUN python -c "import geopandas"
 RUN python -c "import loguru"
requirements_no_versions.txt
ADDED
@@ -0,0 +1,2 @@
+samgis-lisa
+samgis_lisa
samgis_lisa_on_cuda/__init__.py
DELETED
@@ -1,27 +0,0 @@
-"""Get machine learning predictions from geodata raster images"""
-import os
-
-# not used here but contextily_tile is imported in samgis_lisa_on_cuda.io.tms2geotiff
-from contextily import tile as contextily_tile
-from pathlib import Path
-from samgis_lisa_on_cuda.utilities.constants import SERVICE_NAME
-
-ROOT = Path(globals().get("__file__", "./_")).absolute().parent.parent
-PROJECT_ROOT_FOLDER = Path(os.getenv("PROJECT_ROOT_FOLDER", ROOT))
-WORKDIR = Path(os.getenv("WORKDIR", ROOT))
-MODEL_FOLDER_PROJECT_ROOT_FOLDER = Path(PROJECT_ROOT_FOLDER / "machine_learning_models")
-MODEL_FOLDER = Path(os.getenv("MODEL_FOLDER", MODEL_FOLDER_PROJECT_ROOT_FOLDER))
-
-IS_AWS_LAMBDA = bool(os.getenv("IS_AWS_LAMBDA", ""))
-
-if IS_AWS_LAMBDA:
-    try:
-        from aws_lambda_powertools import Logger
-
-        app_logger = Logger(service=SERVICE_NAME)
-    except ModuleNotFoundError:
-        print("this should be AWS LAMBDA environment but we miss the required aws lambda powertools package")
-else:
-    from samgis_core.utilities.fastapi_logger import setup_logging
-
-    app_logger = setup_logging(debug=True)
samgis_lisa_on_cuda/__version__.py
DELETED
@@ -1,4 +0,0 @@
-import importlib.metadata
-
-
-__version__ = importlib.metadata.version(__package__ or __name__)
samgis_lisa_on_cuda/io/__init__.py
DELETED
@@ -1 +0,0 @@
-"""input/output helpers functions"""
samgis_lisa_on_cuda/io/coordinates_pixel_conversion.py
DELETED
@@ -1,99 +0,0 @@
-"""functions useful to convert to/from latitude-longitude coordinates to pixel image coordinates"""
-from samgis_core.utilities.type_hints import TupleFloat, TupleFloatAny
-
-from samgis_lisa_on_cuda import app_logger
-from samgis_lisa_on_cuda.utilities.constants import TILE_SIZE, EARTH_EQUATORIAL_RADIUS
-from samgis_lisa_on_cuda.utilities.type_hints import ImagePixelCoordinates
-from samgis_lisa_on_cuda.utilities.type_hints import LatLngDict
-
-
-def _get_latlng2pixel_projection(latlng: LatLngDict) -> ImagePixelCoordinates:
-    from math import log, pi, sin
-
-    app_logger.debug(f"latlng: {type(latlng)}, value:{latlng}.")
-    app_logger.debug(f'latlng lat: {type(latlng.lat)}, value:{latlng.lat}.')
-    app_logger.debug(f'latlng lng: {type(latlng.lng)}, value:{latlng.lng}.')
-    try:
-        sin_y: float = sin(latlng.lat * pi / 180)
-        app_logger.debug(f"sin_y, #1:{sin_y}.")
-        sin_y = min(max(sin_y, -0.9999), 0.9999)
-        app_logger.debug(f"sin_y, #2:{sin_y}.")
-        x = TILE_SIZE * (0.5 + latlng.lng / 360)
-        app_logger.debug(f"x:{x}.")
-        y = TILE_SIZE * (0.5 - log((1 + sin_y) / (1 - sin_y)) / (4 * pi))
-        app_logger.debug(f"y:{y}.")
-
-        return {"x": x, "y": y}
-    except Exception as e_get_latlng2pixel_projection:
-        app_logger.error(f'args type:{type(latlng)}, {latlng}.')
-        app_logger.exception(f'e_get_latlng2pixel_projection:{e_get_latlng2pixel_projection}.', exc_info=True)
-        raise e_get_latlng2pixel_projection
-
-
-def _get_point_latlng_to_pixel_coordinates(latlng: LatLngDict, zoom: int | float) -> ImagePixelCoordinates:
-    from math import floor
-
-    try:
-        world_coordinate: ImagePixelCoordinates = _get_latlng2pixel_projection(latlng)
-        app_logger.debug(f"world_coordinate:{world_coordinate}.")
-        scale: int = pow(2, zoom)
-        app_logger.debug(f"scale:{scale}.")
-        return ImagePixelCoordinates(
-            x=floor(world_coordinate["x"] * scale),
-            y=floor(world_coordinate["y"] * scale)
-        )
-    except Exception as e_format_latlng_to_pixel_coordinates:
-        app_logger.error(f'latlng type:{type(latlng)}, {latlng}.')
-        app_logger.error(f'zoom type:{type(zoom)}, {zoom}.')
-        app_logger.exception(f'e_format_latlng_to_pixel_coordinates:{e_format_latlng_to_pixel_coordinates}.',
-                             exc_info=True)
-        raise e_format_latlng_to_pixel_coordinates
-
-
-def get_latlng_to_pixel_coordinates(
-        latlng_origin_ne: LatLngDict,
-        latlng_origin_sw: LatLngDict,
-        latlng_current_point: LatLngDict,
-        zoom: int | float,
-        k: str
-) -> ImagePixelCoordinates:
-    """
-    Parse the input request lambda event
-
-    Args:
-        latlng_origin_ne: NE latitude-longitude origin point
-        latlng_origin_sw: SW latitude-longitude origin point
-        latlng_current_point: latitude-longitude prompt point
-        zoom: Level of detail
-        k: prompt type
-
-    Returns:
-        ImagePixelCoordinates: pixel image coordinate point
-    """
-    app_logger.debug(f"latlng_origin - {k}: {type(latlng_origin_ne)}, value:{latlng_origin_ne}.")
-    app_logger.debug(f"latlng_current_point - {k}: {type(latlng_current_point)}, value:{latlng_current_point}.")
-    latlng_map_origin_ne = _get_point_latlng_to_pixel_coordinates(latlng_origin_ne, zoom)
-    latlng_map_origin_sw = _get_point_latlng_to_pixel_coordinates(latlng_origin_sw, zoom)
-    latlng_map_current_point = _get_point_latlng_to_pixel_coordinates(latlng_current_point, zoom)
-    diff_coord_x = abs(latlng_map_origin_sw["x"] - latlng_map_current_point["x"])
-    diff_coord_y = abs(latlng_map_origin_ne["y"] - latlng_map_current_point["y"])
-    point = ImagePixelCoordinates(x=diff_coord_x, y=diff_coord_y)
-    app_logger.debug(f"point type - {k}: {point}.")
-    return point
-
-
-def _from4326_to3857(lat: float, lon: float) -> TupleFloat or TupleFloatAny:
-    from math import radians, log, tan
-
-    x_tile: float = radians(lon) * EARTH_EQUATORIAL_RADIUS
-    y_tile: float = log(tan(radians(45 + lat / 2.0))) * EARTH_EQUATORIAL_RADIUS
-    return x_tile, y_tile
-
-
-def _deg2num(lat: float, lon: float, zoom: int):
-    from math import radians, pi, asinh, tan
-
-    n = 2 ** zoom
-    x_tile = ((lon + 180) / 360 * n)
-    y_tile = (1 - asinh(tan(radians(lat))) / pi) * n / 2
-    return x_tile, y_tile
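For reference, the deleted projection code boils down to this minimal sketch (not part of the commit; TILE_SIZE is assumed to be the usual 256 px XYZ tile size):

# Minimal sketch of the deleted lat/lng -> pixel conversion (illustrative only).
# Assumption: TILE_SIZE is the standard 256 px tile size.
from math import floor, log, pi, sin

TILE_SIZE = 256

def latlng_to_pixel(lat: float, lng: float, zoom: int) -> tuple[int, int]:
    # clamp sin(lat) like the original code, keeping the Mercator log() finite near the poles
    sin_y = min(max(sin(lat * pi / 180), -0.9999), 0.9999)
    world_x = TILE_SIZE * (0.5 + lng / 360)
    world_y = TILE_SIZE * (0.5 - log((1 + sin_y) / (1 - sin_y)) / (4 * pi))
    scale = 2 ** zoom
    return floor(world_x * scale), floor(world_y * scale)

print(latlng_to_pixel(41.9, 12.5, 10))  # e.g. Rome at zoom 10 -> roughly (140174, 97410)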
samgis_lisa_on_cuda/io/geo_helpers.py
DELETED
@@ -1,91 +0,0 @@
-"""handle geo-referenced raster images"""
-from affine import Affine
-from numpy import ndarray as np_ndarray
-
-from samgis_core.utilities.type_hints import ListFloat, TupleFloat, DictStrInt
-from samgis_lisa_on_cuda import app_logger
-
-
-def load_affine_transformation_from_matrix(matrix_source_coefficients: ListFloat) -> Affine:
-    """
-    Wrapper for rasterio.Affine.from_gdal() method
-
-    Args:
-        matrix_source_coefficients: 6 floats ordered by GDAL.
-
-    Returns:
-        Affine transform
-    """
-
-    if len(matrix_source_coefficients) != 6:
-        raise ValueError(f"Expected 6 coefficients, found {len(matrix_source_coefficients)}; "
-                         f"argument type: {type(matrix_source_coefficients)}.")
-
-    try:
-        a, d, b, e, c, f = (float(x) for x in matrix_source_coefficients)
-        center = tuple.__new__(Affine, [a, b, c, d, e, f, 0.0, 0.0, 1.0])
-        return center * Affine.translation(-0.5, -0.5)
-    except Exception as e:
-        app_logger.exception(f"exception:{e}, check updates on https://github.com/rasterio/affine",
-                             extra=e,
-                             stack_info=True, exc_info=True)
-        raise e
-
-
-def get_affine_transform_from_gdal(matrix_source_coefficients: ListFloat or TupleFloat) -> Affine:
-    """wrapper for rasterio Affine from_gdal method
-
-    Args:
-        matrix_source_coefficients: 6 floats ordered by GDAL.
-
-    Returns:
-        Affine transform
-    """
-    return Affine.from_gdal(*matrix_source_coefficients)
-
-
-def get_vectorized_raster_as_geojson(mask: np_ndarray, transform: TupleFloat) -> DictStrInt:
-    """
-    Get shapes and values of connected regions in a dataset or array
-
-    Args:
-        mask: numpy mask
-        transform: tuple of float to transform into an Affine transform
-
-    Returns:
-        dict containing the output geojson and the predictions number
-    """
-    try:
-        from rasterio.features import shapes
-        from geopandas import GeoDataFrame
-
-        app_logger.debug(f"matrix to consume with rasterio.shapes: {type(transform)}, {transform}.")
-
-        # old value for mask => band != 0
-        shapes_generator = ({
-            'properties': {'raster_val': v}, 'geometry': s}
-            for i, (s, v)
-            # instead of `enumerate(shapes(mask, mask=(band != 0), transform=rio_src.transform))`
-            # use mask=None to avoid using source
-            in enumerate(shapes(mask, mask=None, transform=transform))
-        )
-        app_logger.info("created shapes_generator, transform it to a polygon list...")
-        shapes_list = list(shapes_generator)
-        app_logger.info(f"created {len(shapes_list)} polygons.")
-        gpd_polygonized_raster = GeoDataFrame.from_features(shapes_list, crs="EPSG:3857")
-        app_logger.info("created a GeoDataFrame, export to geojson...")
-        geojson = gpd_polygonized_raster.to_json(to_wgs84=True)
-        app_logger.info("created geojson, preparing API response...")
-        return {
-            "geojson": geojson,
-            "n_shapes_geojson": len(shapes_list)
-        }
-    except Exception as e_shape_band:
-        try:
-            app_logger.error(f"mask type:{type(mask)}.")
-            app_logger.error(f"transform type:{type(transform)}, {transform}.")
-            app_logger.error(f"mask shape:{mask.shape}, dtype:{mask.dtype}.")
-        except Exception as e_shape_dtype:
-            app_logger.exception(f"mask shape or dtype not found:{e_shape_dtype}.", exc_info=True)
-        app_logger.exception(f"e_shape_band:{e_shape_band}.", exc_info=True)
-        raise e_shape_band
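For reference, a small usage sketch (not from the commit) of the Affine.from_gdal call that get_affine_transform_from_gdal delegates to; the six coefficients below are hypothetical:

# Illustrative sketch with hypothetical coefficients: GDAL geotransform order is
# (c, a, b, f, d, e) = (origin x, pixel width, row rotation, origin y, column rotation, pixel height).
from affine import Affine

gdal_coefficients = [1468763.0, 10.0, 0.0, 5168323.0, 0.0, -10.0]
transform = Affine.from_gdal(*gdal_coefficients)
print(transform * (0, 0))  # column/row (0, 0) maps to the raster origin (1468763.0, 5168323.0)
print(transform * (2, 1))  # two pixels east, one south -> (1468783.0, 5168313.0)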
samgis_lisa_on_cuda/io/raster_helpers.py
DELETED
@@ -1,331 +0,0 @@
-"""helpers for computer vision duties"""
-import numpy as np
-from numpy import ndarray, bitwise_not
-from rasterio import open as rasterio_open
-
-from samgis_lisa_on_cuda import PROJECT_ROOT_FOLDER
-from samgis_lisa_on_cuda import app_logger
-from samgis_lisa_on_cuda.utilities.constants import OUTPUT_CRS_STRING
-from samgis_lisa_on_cuda.utilities.type_hints import XYZTerrainProvidersNames
-
-
-def get_nextzen_terrain_rgb_formula(red: ndarray, green: ndarray, blue: ndarray) -> ndarray:
-    """
-    Compute a 32-bits 2d digital elevation model from a nextzen 'terrarium' (terrain-rgb) raster.
-    'Terrarium' format PNG tiles contain raw elevation data in meters, in Mercator projection (EPSG:3857).
-    All values are positive with a 32,768 offset, split into the red, green, and blue channels,
-    with 16 bits of integer and 8 bits of fraction. To decode:
-
-        (red * 256 + green + blue / 256) - 32768
-
-    More details on https://www.mapzen.com/blog/elevation/
-
-    Args:
-        red: red-valued channel image array
-        green: green-valued channel image array
-        blue: blue-valued channel image array
-
-    Returns:
-        ndarray: nextzen 'terrarium' 2d digital elevation model raster at 32 bits
-
-    """
-    return (red * 256 + green + blue / 256) - 32768
-
-
-def get_mapbox__terrain_rgb_formula(red: ndarray, green: ndarray, blue: ndarray) -> ndarray:
-    return ((red * 256 * 256 + green * 256 + blue) * 0.1) - 10000
-
-
-providers_terrain_rgb_formulas = {
-    XYZTerrainProvidersNames.MAPBOX_TERRAIN_TILES_NAME: get_mapbox__terrain_rgb_formula,
-    XYZTerrainProvidersNames.NEXTZEN_TERRAIN_TILES_NAME: get_nextzen_terrain_rgb_formula
-}
-
-
-def _get_2d_array_from_3d(arr: ndarray) -> ndarray:
-    return arr.reshape(arr.shape[0], arr.shape[1])
-
-
-def _channel_split(arr: ndarray) -> list[ndarray]:
-    from numpy import dsplit
-
-    return dsplit(arr, arr.shape[-1])
-
-
-def get_raster_terrain_rgb_like(arr: ndarray, xyz_provider_name, nan_value_int: int = -12000):
-    """
-    Compute a 32-bits 2d digital elevation model from a terrain-rgb raster.
-
-    Args:
-        arr: rgb raster
-        xyz_provider_name: xyz provider
-        nan_value_int: threshold int value to replace NaN
-
-    Returns:
-        ndarray: 2d digital elevation model raster at 32 bits
-    """
-    red, green, blue = _channel_split(arr)
-    dem_rgb = providers_terrain_rgb_formulas[xyz_provider_name](red, green, blue)
-    output = _get_2d_array_from_3d(dem_rgb)
-    output[output < nan_value_int] = np.NaN
-    return output
-
-
-def get_rgb_prediction_image(raster_cropped: ndarray, slope_cellsize: int, invert_image: bool = True) -> ndarray:
-    """
-    Return an RGB image from input numpy array
-
-    Args:
-        raster_cropped: input numpy array
-        slope_cellsize: window size to calculate slope and curvature (1st and 2nd degree array derivative)
-        invert_image:
-
-    Returns:
-        tuple of str: image filename, image path (with filename)
-    """
-    from samgis_lisa_on_cuda.utilities.constants import CHANNEL_EXAGGERATIONS_LIST
-
-    try:
-        slope, curvature = get_slope_curvature(raster_cropped, slope_cellsize=slope_cellsize)
-
-        channel0 = raster_cropped
-        channel1 = normalize_array_list(
-            [raster_cropped, slope, curvature], CHANNEL_EXAGGERATIONS_LIST, title="channel1_normlist")
-        channel2 = curvature
-
-        return get_rgb_image(channel0, channel1, channel2, invert_image=invert_image)
-    except ValueError as ve_get_rgb_prediction_image:
-        msg = f"ve_get_rgb_prediction_image:{ve_get_rgb_prediction_image}."
-        app_logger.error(msg)
-        raise ve_get_rgb_prediction_image
-
-
-def get_rgb_image(arr_channel0: ndarray, arr_channel1: ndarray, arr_channel2: ndarray,
-                  invert_image: bool = True) -> ndarray:
-    """
-    Return an RGB image from input R,G,B channel arrays
-
-    Args:
-        arr_channel0: channel image 0
-        arr_channel1: channel image 1
-        arr_channel2: channel image 2
-        invert_image: invert the RGB image channel order
-
-    Returns:
-        ndarray: RGB image
-
-    """
-    try:
-        # RED curvature, GREEN slope, BLUE dem, invert_image=True
-        if len(arr_channel0.shape) != 2:
-            msg = f"arr_size, wrong type:{type(arr_channel0)} or arr_size:{arr_channel0.shape}."
-            app_logger.error(msg)
-            raise ValueError(msg)
-        data_rgb = np.zeros((arr_channel0.shape[0], arr_channel0.shape[1], 3), dtype=np.uint8)
-        app_logger.debug(f"arr_container data_rgb, type:{type(data_rgb)}, arr_shape:{data_rgb.shape}.")
-        data_rgb[:, :, 0] = normalize_array(
-            arr_channel0.astype(float), high=1, norm_type="float", title="RGB:channel0") * 64
-        data_rgb[:, :, 1] = normalize_array(
-            arr_channel1.astype(float), high=1, norm_type="float", title="RGB:channel1") * 128
-        data_rgb[:, :, 2] = normalize_array(
-            arr_channel2.astype(float), high=1, norm_type="float", title="RGB:channel2") * 192
-        if invert_image:
-            app_logger.debug(f"data_rgb:{type(data_rgb)}, {data_rgb.dtype}.")
-            data_rgb = bitwise_not(data_rgb)
-        return data_rgb
-    except ValueError as ve_get_rgb_image:
-        msg = f"ve_get_rgb_image:{ve_get_rgb_image}."
-        app_logger.error(msg)
-        raise ve_get_rgb_image
-
-
-def get_slope_curvature(dem: ndarray, slope_cellsize: int, title: str = "") -> tuple[ndarray, ndarray]:
-    """
-    Return a tuple of two numpy arrays representing slope and curvature (1st grade derivative and 2nd grade derivative)
-
-    Args:
-        dem: input numpy array
-        slope_cellsize: window size to calculate slope and curvature
-        title: array name
-
-    Returns:
-        tuple of ndarrays: slope image, curvature image
-
-    """
-
-    app_logger.info(f"dem shape:{dem.shape}, slope_cellsize:{slope_cellsize}.")
-
-    try:
-        dem = dem.astype(float)
-        app_logger.debug("get_slope_curvature:: start")
-        slope = calculate_slope(dem, slope_cellsize)
-        app_logger.debug("get_slope_curvature:: created slope raster")
-        s2c = calculate_slope(slope, slope_cellsize)
-        curvature = normalize_array(s2c, norm_type="float", title=f"SC:curvature_{title}")
-        app_logger.debug("get_slope_curvature:: created curvature raster")
-
-        return slope, curvature
-    except ValueError as ve_get_slope_curvature:
-        msg = f"ve_get_slope_curvature:{ve_get_slope_curvature}."
-        app_logger.error(msg)
-        raise ve_get_slope_curvature
-
-
-def calculate_slope(dem_array: ndarray, cell_size: int, calctype: str = "degree") -> ndarray:
-    """
-    Return a numpy array representing slope (1st grade derivative)
-
-    Args:
-        dem_array: input numpy array
-        cell_size: window size to calculate slope
-        calctype: calculus type
-
-    Returns:
-        ndarray: slope image
-
-    """
-
-    try:
-        gradx, grady = np.gradient(dem_array, cell_size)
-        dem_slope = np.sqrt(gradx ** 2 + grady ** 2)
-        if calctype == "degree":
-            dem_slope = np.degrees(np.arctan(dem_slope))
-        app_logger.debug(f"extracted slope with calctype:{calctype}.")
-        return dem_slope
-    except ValueError as ve_calculate_slope:
-        msg = f"ve_calculate_slope:{ve_calculate_slope}."
-        app_logger.error(msg)
-        raise ve_calculate_slope
-
-
-def normalize_array(arr: ndarray, high: int = 255, norm_type: str = "float", invert: bool = False, title: str = "") -> ndarray:
-    """
-    Return normalized numpy array between 0 and 'high' value. Default normalization type is int
-
-    Args:
-        arr: input numpy array
-        high: max value to use for normalization
-        norm_type: type of normalization: could be 'float' or 'int'
-        invert: bool to choose if invert the normalized numpy array
-        title: array title name
-
-    Returns:
-        ndarray: normalized numpy array
-
-    """
-    np.seterr("raise")
-
-    h_min_arr = np.nanmin(arr)
-    h_arr_max = np.nanmax(arr)
-    try:
-        h_diff = h_arr_max - h_min_arr
-        app_logger.debug(
-            f"normalize_array:: '{title}',h_min_arr:{h_min_arr},h_arr_max:{h_arr_max},h_diff:{h_diff}, dtype:{arr.dtype}.")
-    except Exception as e_h_diff:
-        app_logger.error(f"e_h_diff:{e_h_diff}.")
-        raise ValueError(e_h_diff)
-
-    if check_empty_array(arr, high) or check_empty_array(arr, h_diff):
-        msg_ve = f"normalize_array::empty array '{title}',h_min_arr:{h_min_arr},h_arr_max:{h_arr_max},h_diff:{h_diff}, dtype:{arr.dtype}."
-        app_logger.error(msg_ve)
-        raise ValueError(msg_ve)
-    try:
-        normalized = high * (arr - h_min_arr) / h_diff
-        normalized = np.nanmax(normalized) - normalized if invert else normalized
-        return normalized.astype(int) if norm_type == "int" else normalized
-    except FloatingPointError as fe:
-        msg = f"normalize_array::{title}:h_arr_max:{h_arr_max},h_min_arr:{h_min_arr},fe:{fe}."
-        app_logger.error(msg)
-        raise ValueError(msg)
-
-
-def normalize_array_list(arr_list: list[ndarray], exaggerations_list: list[float] = None, title: str = "") -> ndarray:
-    """
-    Return a normalized numpy array from a list of numpy array and an optional list of exaggeration values.
-
-    Args:
-        arr_list: list of array to use for normalization
-        exaggerations_list: list of exaggeration values
-        title: array title name
-
-    Returns:
-        ndarray: normalized numpy array
-
-    """
-
-    if not arr_list:
-        msg = f"input list can't be empty:{arr_list}."
-        app_logger.error(msg)
-        raise ValueError(msg)
-    if exaggerations_list is None:
-        exaggerations_list = list(np.ones(len(arr_list)))
-    arr_tmp = np.zeros(arr_list[0].shape)
-    for a, exaggeration in zip(arr_list, exaggerations_list):
-        app_logger.debug(f"normalize_array_list::exaggeration:{exaggeration}.")
-        arr_tmp += normalize_array(a, norm_type="float", title=f"ARRLIST:{title}.") * exaggeration
-    return arr_tmp / len(arr_list)
-
-
-def check_empty_array(arr: ndarray, val: float) -> bool:
-    """
-    Return True if the input numpy array is empty. Check if
-    - all values are all the same value (0, 1 or given 'val' input float value)
-    - all values that are not NaN are a given 'val' float value
-
-    Args:
-        arr: input numpy array
-        val: value to use for check if array is empty
-
-    Returns:
-        bool: True if the input numpy array is empty, False otherwise
-
-    """
-
-    arr_check5_tmp = np.copy(arr)
-    arr_size = arr.shape[0]
-    arr_check3 = np.ones((arr_size, arr_size))
-    check1 = np.array_equal(arr, arr_check3)
-    check2 = np.array_equal(arr, np.zeros((arr_size, arr_size)))
-    arr_check3 *= val
-    check3 = np.array_equal(arr, arr_check3)
-    arr[np.isnan(arr)] = 0
-    check4 = np.array_equal(arr, np.zeros((arr_size, arr_size)))
-    arr_check5 = np.ones((arr_size, arr_size)) * val
-    arr_check5_tmp[np.isnan(arr_check5_tmp)] = val
-    check5 = np.array_equal(arr_check5_tmp, arr_check5)
-    app_logger.debug(f"array checks:{check1}, {check2}, {check3}, {check4}, {check5}.")
-    return check1 or check2 or check3 or check4 or check5
-
-
-def write_raster_png(arr, transform, prefix: str, suffix: str, folder_output_path="/tmp"):
-    from pathlib import Path
-    from rasterio.plot import reshape_as_raster
-
-    output_filename = Path(folder_output_path) / f"{prefix}_{suffix}.png"
-
-    with rasterio_open(
-            output_filename, 'w', driver='PNG',
-            height=arr.shape[0],
-            width=arr.shape[1],
-            count=3,
-            dtype=str(arr.dtype),
-            crs=OUTPUT_CRS_STRING,
-            transform=transform) as dst:
-        dst.write(reshape_as_raster(arr))
-    app_logger.info(f"written:{output_filename} as PNG, use {OUTPUT_CRS_STRING} as CRS.")
-
-
-def write_raster_tiff(arr, transform, prefix: str, suffix: str, folder_output_path="/tmp"):
-    from pathlib import Path
-    output_filename = Path(folder_output_path) / f"{prefix}_{suffix}.tiff"
-
-    with rasterio_open(
-            output_filename, 'w', driver='GTiff',
-            height=arr.shape[0],
-            width=arr.shape[1],
-            count=1,
-            dtype=str(arr.dtype),
-            crs=OUTPUT_CRS_STRING,
-            transform=transform) as dst:
-        dst.write(arr, 1)
-    app_logger.info(f"written:{output_filename} as TIFF, use {OUTPUT_CRS_STRING} as CRS.")
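For reference, the Terrarium decode formula quoted in the get_nextzen_terrain_rgb_formula docstring, worked through on a single pixel (not part of the commit; pixel values are hypothetical):

# One-pixel worked example of the Terrarium (terrain-rgb) decode formula.
# The same expression applies element-wise to full ndarrays.
red, green, blue = 130, 91, 64
elevation = (red * 256 + green + blue / 256) - 32768
print(elevation)  # 33280 + 91 + 0.25 - 32768 = 603.25 metres above sea level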
samgis_lisa_on_cuda/io/tms2geotiff.py
DELETED
@@ -1,181 +0,0 @@
-import os
-
-from numpy import ndarray
-from samgis_core.utilities.type_hints import TupleFloat
-from xyzservices import TileProvider
-
-from samgis_lisa_on_cuda import app_logger
-from samgis_lisa_on_cuda.utilities.constants import (OUTPUT_CRS_STRING, DRIVER_RASTERIO_GTIFF, N_MAX_RETRIES, N_CONNECTION, N_WAIT,
-                                                     ZOOM_AUTO, BOOL_USE_CACHE)
-from samgis_lisa_on_cuda.utilities.type_hints import tuple_ndarray_transform
-
-
-bool_use_cache = int(os.getenv("BOOL_USE_CACHE", BOOL_USE_CACHE))
-n_connection = int(os.getenv("N_CONNECTION", N_CONNECTION))
-n_max_retries = int(os.getenv("N_MAX_RETRIES", N_MAX_RETRIES))
-n_wait = int(os.getenv("N_WAIT", N_WAIT))
-zoom_auto_string = os.getenv("ZOOM_AUTO", ZOOM_AUTO)
-
-
-def download_extent(w: float, s: float, e: float, n: float, zoom: int or str = zoom_auto_string,
-                    source: TileProvider or str = None,
-                    wait: int = n_wait, max_retries: int = n_max_retries, n_connections: int = n_connection,
-                    use_cache: bool = bool_use_cache) -> tuple_ndarray_transform:
-    """
-    Download, merge and crop a list of tiles into a single geo-referenced image or a raster geodata
-
-    Args:
-        w: West edge
-        s: South edge
-        e: East edge
-        n: North edge
-        zoom: Level of detail
-        source: The tile source: web tile provider or path to local file. The web tile provider can be in the form of
-            a :class:`xyzservices.TileProvider` object or a URL. The placeholders for the XYZ in the URL need to be
-            `{x}`, `{y}`, `{z}`, respectively. For local file paths, the file is read with `rasterio` and all bands are
-            loaded into the basemap. IMPORTANT: tiles are assumed to be in the Spherical Mercator projection
-            (EPSG:3857), unless the `crs` keyword is specified.
-        wait: if the tile API is rate-limited, the number of seconds to wait
-            between a failed request and the next try
-        max_retries: total number of rejected requests allowed before contextily will stop trying to fetch more tiles
-            from a rate-limited API.
-        n_connections: Number of connections for downloading tiles in parallel. Be careful not to overload the tile
-            server and to check the tile provider's terms of use before increasing this value. E.g., OpenStreetMap has
-            a max. value of 2 (https://operations.osmfoundation.org/policies/tiles/). If allowed to download in
-            parallel, a recommended value for n_connections is 16, and should never be larger than 64.
-        use_cache: If False, caching of the downloaded tiles will be disabled. This can be useful in resource
-            constrained environments, especially when using n_connections > 1, or when a tile provider's terms of use
-            don't allow caching.
-
-    Returns:
-        parsed request input
-    """
-    try:
-        from samgis_lisa_on_cuda import contextily_tile
-        from samgis_lisa_on_cuda.io.coordinates_pixel_conversion import _from4326_to3857
-
-        app_logger.info(f"connection number:{n_connections}, type:{type(n_connections)}.")
-        app_logger.info(f"zoom:{zoom}, type:{type(zoom)}.")
-        app_logger.debug(f"download raster from source:{source} with bounding box w:{w}, s:{s}, e:{e}, n:{n}.")
-        app_logger.debug(f"types w:{type(w)}, s:{type(s)}, e:{type(e)}, n:{type(n)}.")
-        downloaded_raster, bbox_raster = contextily_tile.bounds2img(
-            w, s, e, n, zoom=zoom, source=source, ll=True, wait=wait, max_retries=max_retries,
-            n_connections=n_connections, use_cache=use_cache)
-        xp0, yp0 = _from4326_to3857(n, e)
-        xp1, yp1 = _from4326_to3857(s, w)
-        cropped_image_ndarray, cropped_transform = crop_raster(yp1, xp1, yp0, xp0, downloaded_raster, bbox_raster)
-        return cropped_image_ndarray, cropped_transform
-    except Exception as e_download_extent:
-        app_logger.exception(f"e_download_extent:{e_download_extent}.", exc_info=True)
-        raise e_download_extent
-
-
-def crop_raster(w: float, s: float, e: float, n: float, raster: ndarray, raster_bbox: TupleFloat,
-                crs: str = OUTPUT_CRS_STRING, driver: str = DRIVER_RASTERIO_GTIFF) -> tuple_ndarray_transform:
-    """
-    Crop a raster using given bounding box (w, s, e, n) values
-
-    Args:
-        w: cropping west edge
-        s: cropping south edge
-        e: cropping east edge
-        n: cropping north edge
-        raster: raster image to crop
-        raster_bbox: bounding box of raster to crop
-        crs: The coordinate reference system. Required in 'w' or 'w+' modes, it is ignored in 'r' or 'r+' modes.
-        driver: A short format driver name (e.g. "GTiff" or "JPEG") or a list of such names (see GDAL docs at
-            https://gdal.org/drivers/raster/index.html ). In 'w' or 'w+' modes a single name is required. In 'r' or 'r+'
-            modes the driver can usually be omitted. Registered drivers will be tried sequentially until a match is
-            found. When multiple drivers are available for a format such as JPEG2000, one of them can be selected by
-            using this keyword argument.
-
-    Returns:
-        cropped raster with its Affine transform
-    """
-    try:
-        from rasterio.io import MemoryFile
-        from rasterio.mask import mask as rio_mask
-        from shapely.geometry import Polygon
-        from geopandas import GeoSeries
-
-        app_logger.debug(f"raster: type {type(raster)}, raster_ext:{type(raster_bbox)}, {raster_bbox}.")
-        img_to_save, transform = get_transform_raster(raster, raster_bbox)
-        img_height, img_width, number_bands = img_to_save.shape
-        # https://rasterio.readthedocs.io/en/latest/topics/memory-files.html
-        with MemoryFile() as rio_mem_file:
-            app_logger.debug("writing raster in-memory to crop it with rasterio.mask.mask()")
-            with rio_mem_file.open(
-                    driver=driver,
-                    height=img_height,
-                    width=img_width,
-                    count=number_bands,
-                    dtype=str(img_to_save.dtype.name),
-                    crs=crs,
-                    transform=transform,
-            ) as src_raster_rw:
-                for band in range(number_bands):
-                    src_raster_rw.write(img_to_save[:, :, band], band + 1)
-            app_logger.debug("cropping raster in-memory with rasterio.mask.mask()")
-            with rio_mem_file.open() as src_raster_ro:
-                shapes_crop_polygon = Polygon([(n, e), (s, e), (s, w), (n, w), (n, e)])
-                shapes_crop = GeoSeries([shapes_crop_polygon])
-                app_logger.debug(f"cropping with polygon::{shapes_crop_polygon}.")
-                cropped_image, cropped_transform = rio_mask(src_raster_ro, shapes=shapes_crop, crop=True)
-                cropped_image_ndarray = reshape_as_image(cropped_image)
-        app_logger.info(f"cropped image::{cropped_image_ndarray.shape}.")
-        return cropped_image_ndarray, cropped_transform
-    except Exception as e_crop_raster:
-        try:
-            app_logger.error(f"raster type:{type(raster)}.")
-            app_logger.error(f"raster shape:{raster.shape}, dtype:{raster.dtype}.")
-        except Exception as e_shape_dtype:
-            app_logger.exception(f"raster shape or dtype not found:{e_shape_dtype}.", exc_info=True)
-        app_logger.exception(f"e_crop_raster:{e_crop_raster}.", exc_info=True)
-        raise e_crop_raster
-
-
-def get_transform_raster(raster: ndarray, raster_bbox: TupleFloat) -> tuple_ndarray_transform:
-    """
-    Convert the input raster image to RGB and extract the Affine
-
-    Args:
-        raster: raster image to geo-reference
-        raster_bbox: bounding box of raster to crop
-
-    Returns:
-        rgb raster image and its Affine transform
-    """
-    try:
-        from rasterio.transform import from_origin
-        from numpy import array as np_array, linspace as np_linspace, uint8 as np_uint8
-        from PIL.Image import fromarray
-
-        app_logger.debug(f"raster: type {type(raster)}, raster_ext:{type(raster_bbox)}, {raster_bbox}.")
-        rgb = fromarray(np_uint8(raster)).convert('RGB')
-        np_rgb = np_array(rgb)
-        img_height, img_width, _ = np_rgb.shape
-
-        min_x, max_x, min_y, max_y = raster_bbox
-        app_logger.debug(f"raster rgb shape:{np_rgb.shape}, raster rgb bbox {raster_bbox}.")
-        x = np_linspace(min_x, max_x, img_width)
-        y = np_linspace(min_y, max_y, img_height)
-        res_x = (x[-1] - x[0]) / img_width
-        res_y = (y[-1] - y[0]) / img_height
-        transform = from_origin(x[0] - res_x / 2, y[-1] + res_y / 2, res_x, res_y)
-        return np_rgb, transform
-    except Exception as e_get_transform_raster:
-        app_logger.error(f"arguments raster: {type(raster)}, {raster}.")
-        app_logger.error(f"arguments raster_bbox: {type(raster_bbox)}, {raster_bbox}.")
-        app_logger.exception(f"e_get_transform_raster:{e_get_transform_raster}.", exc_info=True)
-        raise e_get_transform_raster
-
-
-def reshape_as_image(arr):
-    try:
-        from numpy import swapaxes
-
-        return swapaxes(swapaxes(arr, 0, 2), 0, 1)
-    except Exception as e_reshape_as_image:
-        app_logger.error(f"arguments: {type(arr)}, {arr}.")
-        app_logger.exception(f"e_reshape_as_image:{e_reshape_as_image}.", exc_info=True)
-        raise e_reshape_as_image
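For reference, a minimal sketch (not from the commit) of the contextily call that the deleted download_extent wraps; the bounding-box values are hypothetical:

# Minimal sketch of the tile-download step; bbox values are hypothetical.
# With ll=True, bounds2img accepts WGS84 (lon/lat) bounds and returns the merged
# tile mosaic plus its extent in Spherical Mercator (EPSG:3857).
import contextily as cx

w, s, e, n = 9.18, 45.45, 9.20, 45.47
img, extent = cx.bounds2img(w, s, e, n, zoom=14, ll=True)
print(img.shape, extent)  # (height, width, bands) array and (min_x, max_x, min_y, max_y)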
samgis_lisa_on_cuda/io/wrappers_helpers.py
DELETED
@@ -1,262 +0,0 @@
-"""lambda helper functions"""
-import logging
-from sys import stdout
-from typing import Dict
-
-import loguru
-from xyzservices import providers, TileProvider
-
-from lisa_on_cuda.utils.app_helpers import get_cleaned_input
-from samgis_lisa_on_cuda import app_logger
-from samgis_lisa_on_cuda.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
-from samgis_lisa_on_cuda.utilities.constants import COMPLETE_URL_TILES_MAPBOX, COMPLETE_URL_TILES_NEXTZEN, CUSTOM_RESPONSE_MESSAGES
-from samgis_lisa_on_cuda.utilities.type_hints import ApiRequestBody, ContentTypes, XYZTerrainProvidersNames, \
-    XYZDefaultProvidersNames, StringPromptApiRequestBody
-from samgis_core.utilities.utilities import base64_decode
-
-
-def get_response(status: int, start_time: float, request_id: str, response_body: Dict = None) -> str:
-    """
-    Response composer
-
-    Args:
-        status: status response
-        start_time: request start time (float)
-        request_id: str
-        response_body: dict we embed into our response
-
-    Returns:
-        json response
-
-    """
-    from json import dumps
-    from time import time
-
-    app_logger.debug(f"response_body:{response_body}.")
-    response_body["duration_run"] = time() - start_time
-    response_body["message"] = CUSTOM_RESPONSE_MESSAGES[status]
-    response_body["request_id"] = request_id
-
-    response = {
-        "statusCode": status,
-        "header": {"Content-Type": ContentTypes.APPLICATION_JSON},
-        "body": dumps(response_body),
-        "isBase64Encoded": False
-    }
-    app_logger.debug(f"response type:{type(response)} => {response}.")
-    return dumps(response)
-
-
-def get_parsed_bbox_points_with_string_prompt(request_input: StringPromptApiRequestBody) -> Dict:
-    """
-    Parse the raw input request into bbox, prompt string and zoom
-
-    Args:
-        request_input: input dict
-
-    Returns:
-        dict with bounding box, prompt string and zoom
-    """
-
-    app_logger.info(f"try to parsing input request {request_input}...")
-
-    bbox = request_input.bbox
-    app_logger.debug(f"request bbox: {type(bbox)}, value:{bbox}.")
-    ne = bbox.ne
-    sw = bbox.sw
-    app_logger.debug(f"request ne: {type(ne)}, value:{ne}.")
-    app_logger.debug(f"request sw: {type(sw)}, value:{sw}.")
-    ne_latlng = [float(ne.lat), float(ne.lng)]
-    sw_latlng = [float(sw.lat), float(sw.lng)]
-    new_zoom = int(request_input.zoom)
-    cleaned_prompt = get_cleaned_input(request_input.string_prompt)
-
-    app_logger.debug(f"bbox => {bbox}.")
-    app_logger.debug(f'request_input-prompt cleaned => {cleaned_prompt}.')
-
-    app_logger.info("unpacking elaborated request...")
-    return {
-        "bbox": [ne_latlng, sw_latlng],
-        "prompt": cleaned_prompt,
-        "zoom": new_zoom,
-        "source": get_url_tile(request_input.source_type)
-    }
-
-
-def get_parsed_bbox_points_with_dictlist_prompt(request_input: ApiRequestBody) -> Dict:
-    """
-    Parse the raw input request into bbox, prompt and zoom
-
-    Args:
-        request_input: input dict
-
-    Returns:
-        dict with bounding box, prompt and zoom
-    """
-
-    app_logger.info(f"try to parsing input request {request_input}...")
-
-    bbox = request_input.bbox
-    app_logger.debug(f"request bbox: {type(bbox)}, value:{bbox}.")
-    ne = bbox.ne
-    sw = bbox.sw
-    app_logger.debug(f"request ne: {type(ne)}, value:{ne}.")
-    app_logger.debug(f"request sw: {type(sw)}, value:{sw}.")
-    ne_latlng = [float(ne.lat), float(ne.lng)]
-    sw_latlng = [float(sw.lat), float(sw.lng)]
-    new_zoom = int(request_input.zoom)
-    new_prompt_list = _get_parsed_prompt_list(ne, sw, new_zoom, request_input.prompt)
-
-    app_logger.debug(f"bbox => {bbox}.")
-    app_logger.debug(f'request_input-prompt updated => {new_prompt_list}.')
-
-    app_logger.info("unpacking elaborated request...")
-    return {
-        "bbox": [ne_latlng, sw_latlng],
-        "prompt": new_prompt_list,
-        "zoom": new_zoom,
-        "source": get_url_tile(request_input.source_type)
-    }
-
-
-def _get_parsed_prompt_list(bbox_ne, bbox_sw, zoom, prompt_list):
-    new_prompt_list = []
-    for prompt in prompt_list:
-        app_logger.debug(f"current prompt: {type(prompt)}, value:{prompt}.")
-        new_prompt = {"type": prompt.type.value}
-        if prompt.type == "point":
-            new_prompt_data = _get_new_prompt_data_point(bbox_ne, bbox_sw, prompt, zoom)
-            new_prompt["label"] = prompt.label.value
-        elif prompt.type == "rectangle":
-            new_prompt_data = _get_new_prompt_data_rectangle(bbox_ne, bbox_sw, prompt, zoom)
-        else:
-            msg = "Valid prompt type: 'point' or 'rectangle', not '{}'. Check ApiRequestBody parsing/validation."
-            raise TypeError(msg.format(prompt.type))
-        app_logger.debug(f"new_prompt_data: {type(new_prompt_data)}, value:{new_prompt_data}.")
-        new_prompt["data"] = new_prompt_data
-        new_prompt_list.append(new_prompt)
-    return new_prompt_list
-
-
-def _get_new_prompt_data_point(bbox_ne, bbox_sw, prompt, zoom):
-    current_point = get_latlng_to_pixel_coordinates(bbox_ne, bbox_sw, prompt.data, zoom, prompt.type)
-    app_logger.debug(f"current prompt: {type(current_point)}, value:{current_point}, label: {prompt.label}.")
-    return [current_point['x'], current_point['y']]
-
-
-def _get_new_prompt_data_rectangle(bbox_ne, bbox_sw, prompt, zoom):
-    current_point_ne = get_latlng_to_pixel_coordinates(bbox_ne, bbox_sw, prompt.data.ne, zoom, prompt.type)
-    app_logger.debug(
-        f"rectangle:: current_point_ne prompt: {type(current_point_ne)}, value:{current_point_ne}.")
-    current_point_sw = get_latlng_to_pixel_coordinates(bbox_ne, bbox_sw, prompt.data.sw, zoom, prompt.type)
-    app_logger.debug(
-        f"rectangle:: current_point_sw prompt: {type(current_point_sw)}, value:{current_point_sw}.")
-    # correct order for rectangle prompt
-    return [
-        current_point_sw["x"],
-        current_point_ne["y"],
-        current_point_ne["x"],
-        current_point_sw["y"]
-    ]
-
-
-def get_parsed_request_body(event: Dict or str) -> ApiRequestBody:
-    """
-    Validator for the raw input request lambda event
-
-    Args:
-        event: input dict
-
-    Returns:
-        parsed request input
-    """
-    from json import dumps, loads
-    from logging import getLevelName
-
-    def _get_current_log_level(logger: loguru.logger) -> [str, loguru._logger.Level]:
-        levels = logger._core.levels
-        current_log_level = logger._core.min_level
-        level_filt = [l for l in levels.items() if l[1].no == current_log_level]
-        return level_filt[0]
-
-    app_logger.info(f"event:{dumps(event)}...")
-    try:
-        raw_body = event["body"]
-    except Exception as e_constants1:
-        app_logger.error(f"e_constants1:{e_constants1}.")
-        raw_body = event
-    app_logger.debug(f"raw_body, #1: {type(raw_body)}, {raw_body}...")
-    if isinstance(raw_body, str):
-        body_decoded_str = base64_decode(raw_body)
-        app_logger.debug(f"body_decoded_str: {type(body_decoded_str)}, {body_decoded_str}...")
-        raw_body = loads(body_decoded_str)
-    app_logger.info(f"body, #2: {type(raw_body)}, {raw_body}...")
-
-    parsed_body = ApiRequestBody.model_validate(raw_body)
-    log_level = "DEBUG" if parsed_body.debug else "INFO"
-    app_logger.remove()
-    app_logger.add(stdout, level=log_level)
-    try:
-        current_log_level_name, _ = _get_current_log_level(app_logger)
-        app_logger.warning(f"set log level to {getLevelName(current_log_level_name)}.")
-    except Exception as ex:
-        print("failing setting parsing bbox, logger is ok? ex:", ex, "#")
-
-    return parsed_body
-
-
-mapbox_terrain_rgb = TileProvider(
-    name=XYZTerrainProvidersNames.MAPBOX_TERRAIN_TILES_NAME,
-    url=COMPLETE_URL_TILES_MAPBOX,
-    attribution=""
-)
-nextzen_terrain_rgb = TileProvider(
-    name=XYZTerrainProvidersNames.NEXTZEN_TERRAIN_TILES_NAME,
-    url=COMPLETE_URL_TILES_NEXTZEN,
-    attribution=""
-)
-
-
-def get_url_tile(source_type: str):
-    try:
-        match source_type.lower():
-            case XYZDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
-                return providers.query_name(XYZDefaultProvidersNames.DEFAULT_TILES_NAME)
-            case XYZTerrainProvidersNames.MAPBOX_TERRAIN_TILES_NAME:
-                return mapbox_terrain_rgb
-            case XYZTerrainProvidersNames.NEXTZEN_TERRAIN_TILES_NAME:
-                app_logger.info("nextzen_terrain_rgb:", nextzen_terrain_rgb)
-                return nextzen_terrain_rgb
-            case _:
-                return providers.query_name(source_type)
-    except ValueError as ve:
-        from pydantic_core import ValidationError
-
-        app_logger.error("ve:", str(ve))
-        raise ValidationError(ve)
-
-
-def check_source_type_is_terrain(source: str | TileProvider):
-    return isinstance(source, TileProvider) and source.name in list(XYZTerrainProvidersNames)
-
-
-def get_source_name(source: str | TileProvider) -> str | bool:
-    try:
-        match source.lower():
-            case XYZDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
-                source_output = providers.query_name(XYZDefaultProvidersNames.DEFAULT_TILES_NAME)
-            case _:
-                source_output = providers.query_name(source)
-        if isinstance(source_output, str):
-            return source_output
-        try:
-            source_dict = dict(source_output)
-            app_logger.info(f"source_dict:{type(source_dict)}, {'name' in source_dict}, source_dict:{source_dict}.")
-            return source_dict["name"]
-        except KeyError as ke:
-            app_logger.error(f"ke:{ke}.")
-    except ValueError as ve:
-        app_logger.info(f"source name::{source}, ve:{ve}.")
-    app_logger.info(f"source name::{source}.")
-
-    return False
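For reference, the body-decoding path of the deleted get_parsed_request_body, sketched with the standard library instead of samgis_core's base64_decode helper (the payload below is hypothetical, not from the commit):

# Sketch of decoding an AWS Lambda event body; payload values are hypothetical.
import base64
import json

payload = {"zoom": 10, "debug": True}
event = {"body": base64.b64encode(json.dumps(payload).encode()).decode()}

raw_body = event["body"]
if isinstance(raw_body, str):
    raw_body = json.loads(base64.b64decode(raw_body))
print(raw_body)  # {'zoom': 10, 'debug': True}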
samgis_lisa_on_cuda/prediction_api/__init__.py
DELETED
@@ -1 +0,0 @@
-"""functions useful to handle machine learning models"""
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
models_dict = {
|
2 |
-
"fastsam": {"instance": None},
|
3 |
-
"lisa": {"inference": None}
|
4 |
-
}
|
5 |
-
embedding_dict = {}
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
samgis_lisa_on_cuda/prediction_api/lisa.py
DELETED
@@ -1,77 +0,0 @@
|
|
1 |
-
from datetime import datetime
|
2 |
-
|
3 |
-
from lisa_on_cuda.utils import app_helpers
|
4 |
-
from samgis_core.utilities.type_hints import LlistFloat, DictStrInt
|
5 |
-
from samgis_lisa_on_cuda import app_logger
|
6 |
-
from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
|
7 |
-
from samgis_lisa_on_cuda.io.raster_helpers import write_raster_png, write_raster_tiff
|
8 |
-
-from samgis_lisa_on_cuda.io.tms2geotiff import download_extent
-from samgis_lisa_on_cuda.prediction_api.global_models import models_dict
-from samgis_lisa_on_cuda.utilities.constants import DEFAULT_URL_TILES
-
-msg_write_tmp_on_disk = "found option to write images and geojson output..."
-
-
-def lisa_predict(
-        bbox: LlistFloat,
-        prompt: str,
-        zoom: float,
-        inference_function_name_key: str = "lisa",
-        source: str = DEFAULT_URL_TILES,
-        source_name: str = None
-) -> DictStrInt:
-    """
-    Return predictions as a geojson from a geo-referenced image using the given input prompt.
-
-    1. if necessary instantiate a segment anything machine learning instance model
-    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
-    3. get a prediction image from the segment anything instance model using the input prompt
-    4. get a geo-referenced geojson from the prediction image
-
-    Args:
-        bbox: coordinates bounding box
-        prompt: machine learning input prompt
-        zoom: Level of detail
-        inference_function_name_key: machine learning model name
-        source: xyz
-        source_name: name of tile provider
-
-    Returns:
-        Affine transform
-    """
-    from os import getenv
-
-    app_logger.info("start lisa inference...")
-    if models_dict[inference_function_name_key]["inference"] is None:
-        app_logger.info(f"missing inference function {inference_function_name_key}, instantiating it now!")
-        parsed_args = app_helpers.parse_args([])
-        inference_fn = app_helpers.get_inference_model_by_args(parsed_args)
-        models_dict[inference_function_name_key]["inference"] = inference_fn
-    app_logger.debug(f"using a {inference_function_name_key} instance model...")
-    inference_fn = models_dict[inference_function_name_key]["inference"]
-
-    pt0, pt1 = bbox
-    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
-    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
-    app_logger.info(
-        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
-    folder_write_tmp_on_disk = getenv("WRITE_TMP_ON_DISK", "")
-    prefix = f"w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}_"
-    if bool(folder_write_tmp_on_disk):
-        now = datetime.now().strftime('%Y%m%d_%H%M%S')
-        app_logger.info(msg_write_tmp_on_disk + f"with coords {prefix}, shape:{img.shape}, {len(img.shape)}.")
-        if img.shape and len(img.shape) == 2:
-            write_raster_tiff(img, transform, f"{source_name}_{prefix}_{now}_", f"raw_tiff", folder_write_tmp_on_disk)
-        if img.shape and len(img.shape) == 3 and img.shape[2] == 3:
-            write_raster_png(img, transform, f"{source_name}_{prefix}_{now}_", f"raw_img", folder_write_tmp_on_disk)
-    else:
-        app_logger.info("keep all temp data in memory...")
-
-    app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
-    embedding_key = f"{source_name}_z{zoom}_{prefix}"
-    _, mask, output_string = inference_fn(prompt, img, app_logger, embedding_key)
-    # app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
-    return {
-        "output_string": output_string,
-        **get_vectorized_raster_as_geojson(mask, transform)
-    }
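For context, a minimal sketch of how the deleted lisa_predict entry point is driven (a hypothetical call: the bbox is a [north-east, south-west] pair of (lat, lng) points, values borrowed from the test fixtures further down):

    # hypothetical call mirroring the signature of the deleted lisa_predict
    bbox = [[39.036, 15.040], [38.303, 13.634]]  # [ne, sw] corners as (lat, lng)
    result = lisa_predict(bbox=bbox, prompt="segment the water bodies", zoom=10, source_name="openstreetmap")
    # result carries "output_string" plus the "geojson"/"n_shapes_geojson" keys from the vectorization step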
samgis_lisa_on_cuda/prediction_api/predictors.py
DELETED
@@ -1,72 +0,0 @@
-"""functions using machine learning instance model(s)"""
-from samgis_lisa_on_cuda import app_logger, MODEL_FOLDER
-from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
-from samgis_lisa_on_cuda.io.raster_helpers import get_raster_terrain_rgb_like, get_rgb_prediction_image
-from samgis_lisa_on_cuda.io.tms2geotiff import download_extent
-from samgis_lisa_on_cuda.io.wrappers_helpers import check_source_type_is_terrain
-from samgis_lisa_on_cuda.prediction_api.global_models import models_dict, embedding_dict
-from samgis_lisa_on_cuda.utilities.constants import DEFAULT_URL_TILES, SLOPE_CELLSIZE
-from samgis_core.prediction_api.sam_onnx import SegmentAnythingONNX, get_raster_inference_with_embedding_from_dict
-from samgis_core.utilities.constants import MODEL_ENCODER_NAME, MODEL_DECODER_NAME, DEFAULT_INPUT_SHAPE
-from samgis_core.utilities.type_hints import LlistFloat, DictStrInt, ListDict
-
-
-def samexporter_predict(
-        bbox: LlistFloat,
-        prompt: ListDict,
-        zoom: float,
-        model_name_key: str = "fastsam",
-        source: str = DEFAULT_URL_TILES,
-        source_name: str = None
-) -> DictStrInt:
-    """
-    Return predictions as a geojson from a geo-referenced image using the given input prompt.
-
-    1. if necessary instantiate a segment anything machine learning instance model
-    2. download a geo-referenced raster image delimited by the coordinates bounding box (bbox)
-    3. get a prediction image from the segment anything instance model using the input prompt
-    4. get a geo-referenced geojson from the prediction image
-
-    Args:
-        bbox: coordinates bounding box
-        prompt: machine learning input prompt
-        zoom: Level of detail
-        model_name_key: machine learning model name
-        source: xyz
-        source_name: name of tile provider
-
-    Returns:
-        Affine transform
-    """
-    if models_dict[model_name_key]["instance"] is None:
-        app_logger.info(f"missing instance model {model_name_key}, instantiating it now!")
-        model_instance = SegmentAnythingONNX(
-            encoder_model_path=MODEL_FOLDER / MODEL_ENCODER_NAME,
-            decoder_model_path=MODEL_FOLDER / MODEL_DECODER_NAME
-        )
-        models_dict[model_name_key]["instance"] = model_instance
-    app_logger.debug(f"using a {model_name_key} instance model...")
-    models_instance = models_dict[model_name_key]["instance"]
-
-    pt0, pt1 = bbox
-    app_logger.info(f"tile_source: {source}: downloading geo-referenced raster with bbox {bbox}, zoom {zoom}.")
-    img, transform = download_extent(w=pt1[1], s=pt1[0], e=pt0[1], n=pt0[0], zoom=zoom, source=source)
-    if check_source_type_is_terrain(source):
-        app_logger.info("terrain-rgb like raster: transforms it into a DEM")
-        dem = get_raster_terrain_rgb_like(img, source.name)
-        # set a slope cell size proportional to the image width
-        slope_cellsize = int(img.shape[1] * SLOPE_CELLSIZE / DEFAULT_INPUT_SHAPE[1])
-        app_logger.info(f"terrain-rgb like raster: compute slope, curvature using {slope_cellsize} as cell size.")
-        img = get_rgb_prediction_image(dem, slope_cellsize)
-
-    app_logger.info(
-        f"img type {type(img)} with shape/size:{img.size}, transform type: {type(transform)}, transform:{transform}.")
-    app_logger.info(f"source_name:{source_name}, source_name type:{type(source_name)}.")
-    embedding_key = f"{source_name}_z{zoom}_w{pt1[1]},s{pt1[0]},e{pt0[1]},n{pt0[0]}"
-    mask, n_predictions = get_raster_inference_with_embedding_from_dict(
-        img, prompt, models_instance, model_name_key, embedding_key, embedding_dict)
-    app_logger.info(f"created {n_predictions} masks, preparing conversion to geojson...")
-    return {
-        "n_predictions": n_predictions,
-        **get_vectorized_raster_as_geojson(mask, transform)
-    }
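A minimal sketch of how the deleted samexporter_predict was invoked; the point-style prompt entry follows the shape used in the request-parsing tests below and is illustrative only:

    bbox = [[38.04, 15.37], [37.46, 14.63]]  # [ne, sw] (lat, lng) corners
    prompt = [{"type": "point", "data": {"lat": 37.7, "lng": 15.0}, "label": 1}]
    out = samexporter_predict(bbox=bbox, prompt=prompt, zoom=10, model_name_key="fastsam")
    # out carries "n_predictions" plus the vectorized "geojson"/"n_shapes_geojson" keys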
samgis_lisa_on_cuda/utilities/__init__.py
DELETED
@@ -1 +0,0 @@
-"""various helpers functions"""
samgis_lisa_on_cuda/utilities/constants.py
DELETED
@@ -1,39 +0,0 @@
-"""Project constants"""
-INPUT_CRS_STRING = "EPSG:4326"
-OUTPUT_CRS_STRING = "EPSG:3857"
-DRIVER_RASTERIO_GTIFF = "GTiff"
-CUSTOM_RESPONSE_MESSAGES = {
-    200: "ok",
-    400: "Bad Request",
-    422: "Missing required parameter",
-    500: "Internal server error"
-}
-TILE_SIZE = 256
-EARTH_EQUATORIAL_RADIUS = 6378137.0
-WKT_3857 = 'PROJCS["WGS 84 / Pseudo-Mercator",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,'
-WKT_3857 += 'AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],'
-WKT_3857 += 'UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],'
-WKT_3857 += 'PROJECTION["Mercator_1SP"],PARAMETER["central_meridian",0],PARAMETER["scale_factor",1],'
-WKT_3857 += 'PARAMETER["false_easting",0],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],'
-WKT_3857 += 'AXIS["X",EAST],AXIS["Y",NORTH],EXTENSION["PROJ4","+proj=merc +a=6378137 +b=6378137 +lat_ts=0.0 +lon_0=0.0 '
-WKT_3857 += '+x_0=0.0 +y_0=0 +k=1.0 +units=m +nadgrids=@null +wktext +no_defs"],AUTHORITY["EPSG","3857"]]'
-SERVICE_NAME = "sam-gis"
-DEFAULT_LOG_LEVEL = 'INFO'
-RETRY_DOWNLOAD = 3
-TIMEOUT_DOWNLOAD = 60
-CALLBACK_INTERVAL_DOWNLOAD = 0.05
-BOOL_USE_CACHE = True
-N_WAIT = 0
-N_MAX_RETRIES = 2
-N_CONNECTION = 2
-ZOOM_AUTO = "auto"
-DEFAULT_URL_TILES = 'https://tile.openstreetmap.org/{z}/{x}/{y}.png'
-DOMAIN_URL_TILES_MAPBOX = "api.mapbox.com"
-RELATIVE_URL_TILES_MAPBOX = "v/mapbox.terrain-rgb/{zoom}/{x}/{y}{@2x}.pngraw?access_token={TOKEN}"
-COMPLETE_URL_TILES_MAPBOX = f"https://{DOMAIN_URL_TILES_MAPBOX}/{RELATIVE_URL_TILES_MAPBOX}"
-# https://s3.amazonaws.com/elevation-tiles-prod/terrarium/13/1308/3167.png
-DOMAIN_URL_TILES_NEXTZEN = "s3.amazonaws.com"
-RELATIVE_URL_TILES_NEXTZEN = "elevation-tiles-prod/terrarium/{z}/{x}/{y}.png"  # "terrarium/{z}/{x}/{y}.png"
-COMPLETE_URL_TILES_NEXTZEN = f"https://{DOMAIN_URL_TILES_NEXTZEN}/{RELATIVE_URL_TILES_NEXTZEN}"
-CHANNEL_EXAGGERATIONS_LIST = [2.5, 1.1, 2.0]
-SLOPE_CELLSIZE = 61
samgis_lisa_on_cuda/utilities/session_logger.py
DELETED
@@ -1,65 +0,0 @@
-import contextvars
-import logging
-from functools import wraps
-from typing import Callable, Tuple
-
-logging_uuid = contextvars.ContextVar("uuid")
-default_formatter = '%(asctime)s | %(uuid)s [%(pathname)s:%(module)s %(lineno)d] %(levelname)s | %(message)s'
-
-
-loggingType = logging.CRITICAL | logging.ERROR | logging.WARNING | logging.INFO | logging.DEBUG
-
-
-def setup_logging(
-        debug: bool = False, formatter: str = default_formatter, name: str = "logger"
-) -> Tuple[logging, contextvars.ContextVar]:
-    """
-    Create a logging instance with log string formatter.
-
-    Args:
-        debug: logging debug argument
-        formatter: log string formatter
-        name: logger name
-
-    Returns:
-        Logger
-
-    """
-
-    old_factory = logging.getLogRecordFactory()
-
-    def record_factory(*args, **kwargs):
-        record = old_factory(*args, **kwargs)
-        record.uuid = logging_uuid.get("uuid")
-        if isinstance(record.msg, str):
-            record.msg = record.msg.replace("\\", "\\\\").replace("\n", "\\n")
-        return record
-
-    logging.setLogRecordFactory(record_factory)
-    logging.basicConfig(level=logging.DEBUG, format=default_formatter, force=True)
-
-    logger = logging.getLogger(name=name)
-
-    # create a console handler
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.DEBUG)
-
-    # create formatter and add to the console
-    formatter = logging.Formatter(formatter)
-    ch.setFormatter(formatter)
-
-    # add the console handler to logger
-    logger.addHandler(ch)
-    return logger, logging_uuid
-
-
-def set_uuid_logging(func: Callable) -> Callable:
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        import uuid
-
-        current_uuid = f"{uuid.uuid4()}"
-        logging_uuid.set(current_uuid)
-        return func(*args, **kwargs)
-
-    return wrapper
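The deleted session_logger injects a per-request UUID into every log record through a contextvar; a minimal usage sketch of its two public helpers (illustrative, based only on the signatures above):

    logger, uuid_var = setup_logging(name="samgis")

    @set_uuid_logging
    def handle_request():
        # every record emitted inside this call shares the UUID set by the decorator
        logger.info("processing request")

    handle_request()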
samgis_lisa_on_cuda/utilities/type_hints.py
DELETED
@@ -1,111 +0,0 @@
-"""custom type hints"""
-from enum import IntEnum, Enum
-from typing import TypedDict
-
-from affine import Affine
-from numpy import ndarray
-from pydantic import BaseModel
-
-
-tuple_ndarray_transform = tuple[ndarray, Affine]
-
-
-class XYZDefaultProvidersNames(str, Enum):
-    """Default xyz provider names"""
-    DEFAULT_TILES_NAME_SHORT = "openstreetmap"
-    DEFAULT_TILES_NAME = "openstreetmap.mapnik"
-
-
-class XYZTerrainProvidersNames(str, Enum):
-    """Custom xyz provider names for digital elevation models"""
-    MAPBOX_TERRAIN_TILES_NAME = "mapbox.terrain-rgb"
-    NEXTZEN_TERRAIN_TILES_NAME = "nextzen.terrarium"
-
-
-class LatLngDict(BaseModel):
-    """Generic geographic latitude-longitude type"""
-    lat: float
-    lng: float
-
-
-class ContentTypes(str, Enum):
-    """Segment Anything: validation point prompt type"""
-    APPLICATION_JSON = "application/json"
-    TEXT_PLAIN = "text/plain"
-    TEXT_HTML = "text/html"
-
-
-class PromptPointType(str, Enum):
-    """Segment Anything: validation point prompt type"""
-    point = "point"
-
-
-class PromptRectangleType(str, Enum):
-    """Segment Anything: validation rectangle prompt type"""
-    rectangle = "rectangle"
-
-
-class PromptLabel(IntEnum):
-    """Valid prompt label type"""
-    EXCLUDE = 0
-    INCLUDE = 1
-
-
-class ImagePixelCoordinates(TypedDict):
-    """Image pixel coordinates type"""
-    x: int
-    y: int
-
-
-class RawBBox(BaseModel):
-    """Input lambda bbox request type (not yet parsed)"""
-    ne: LatLngDict
-    sw: LatLngDict
-
-
-class RawPromptPoint(BaseModel):
-    """Input lambda prompt request of type 'PromptPointType' - point (not yet parsed)"""
-    type: PromptPointType
-    data: LatLngDict
-    label: PromptLabel
-
-
-class RawPromptRectangle(BaseModel):
-    """Input lambda prompt request of type 'PromptRectangleType' - rectangle (not yet parsed)"""
-    type: PromptRectangleType
-    data: RawBBox
-
-    def get_type_str(self):
-        return self.type
-
-
-class ApiRequestBody(BaseModel):
-    """Input lambda request validator type (not yet parsed)"""
-    id: str = ""
-    bbox: RawBBox
-    prompt: list[RawPromptPoint | RawPromptRectangle]
-    zoom: int | float
-    source_type: str = "OpenStreetMap.Mapnik"
-    debug: bool = False
-
-
-class StringPromptApiRequestBody(BaseModel):
-    """Input lambda request validator type (not yet parsed)"""
-    id: str = ""
-    bbox: RawBBox
-    string_prompt: str
-    zoom: int | float
-    source_type: str = "OpenStreetMap.Mapnik"
-    debug: bool = False
-
-
-class ApiResponseBodyFailure(BaseModel):
-    duration_run: float
-    message: str
-    request_id: str
-
-
-class ApiResponseBodySuccess(ApiResponseBodyFailure):
-    n_predictions: int
-    geojson: str
-    n_shapes_geojson: int
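These request models rely on pydantic v2 validation, as the surviving tests further down still show via model_validate; a small self-contained sketch of the same pattern (field values are illustrative):

    from pydantic import BaseModel

    class LatLngDict(BaseModel):
        lat: float
        lng: float

    class RawBBox(BaseModel):
        ne: LatLngDict
        sw: LatLngDict

    bbox = RawBBox.model_validate({"ne": {"lat": 38.0, "lng": 15.4}, "sw": {"lat": 37.5, "lng": 14.6}})
    print(bbox.ne.lat)  # 38.0; an invalid payload raises pydantic.ValidationError instead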
scripts/create_folders_and_variables_if_not_exists.py
DELETED
@@ -1,55 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-
-
-def stats_pathname(pathname: Path | str):
-    current_pathname = Path(pathname)
-    return current_pathname.is_dir()
-
-
-def create_folder_if_not_exists(pathname: Path | str):
-    current_pathname = Path(pathname)
-    try:
-        print(f"Pathname exists? {current_pathname.exists()}, That's a folder? {current_pathname.is_dir()}...")
-        logging.info(f"Pathname exists? {current_pathname.exists()}, That's a folder? {current_pathname.is_dir()}...")
-        current_pathname.unlink(missing_ok=True)
-    except PermissionError as pe:
-        print(f"permission denied on removing pathname before folder creation:{pe}.")
-        logging.error(f"permission denied on removing pathname before folder creation:{pe}.")
-    except IsADirectoryError as errdir:
-        print(f"that's a directory:{errdir}.")
-        logging.error(f"that's a directory:{errdir}.")
-
-    print(f"Creating pathname: {current_pathname} ...")
-    logging.info(f"Creating pathname: {current_pathname} ...")
-    current_pathname.mkdir(mode=0o770, parents=True, exist_ok=True)
-
-    print(f"assertion: pathname exists and is a folder: {current_pathname} ...")
-    logging.info(f"assertion: pathname exists and is a folder: {current_pathname} ...")
-    assert current_pathname.is_dir()
-
-
-def run_folder_creation():
-    folders_string = os.getenv("FOLDERS_MAP")
-    try:
-        folders_dict = json.loads(folders_string)
-        for folder_env_ref, folder_env_path in folders_dict.items():
-            print(f"folder_env_ref:{folder_env_ref}, folder_env_path:{folder_env_path}.")
-            logging.info(f"folder_env_ref:{folder_env_ref}, folder_env_path:{folder_env_path}.")
-            create_folder_if_not_exists(folder_env_path)
-            print("========")
-            assert os.getenv(folder_env_ref) == folder_env_path
-    except (json.JSONDecodeError, TypeError) as jde:
-        print(f"jde:{jde}.")
-        logging.error(f"jde:{jde}.")
-        print("double check your variables, e.g. for mispelling like 'FOLDER_MAP'...")
-        logging.info("double check your variables, e.g. for mispelling like 'FOLDER_MAP' instead than 'FOLDERS_MAP'...")
-        for k_env, v_env in dict(os.environ).items():
-            print(f"{k_env}, v_env:{v_env}.")
-            logging.info(f"{k_env}, v_env:{v_env}.")
-
-
-if __name__ == '__main__':
-    run_folder_creation()
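The removed script expects a FOLDERS_MAP environment variable holding a JSON object that maps an environment-variable name to the folder it should point at; a hypothetical setup showing the expected shape (names and paths are illustrative):

    import json
    import os

    os.environ["MODEL_FOLDER"] = "/data/models"
    os.environ["FOLDERS_MAP"] = json.dumps({"MODEL_FOLDER": "/data/models"})
    # run_folder_creation() then creates each folder and asserts os.getenv(name) == path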
scripts/extract-openapi-fastapi.py
CHANGED
@@ -3,12 +3,13 @@ import argparse
 import json
 import logging
 import sys
+from pathlib import Path
 
 import yaml
 from uvicorn.importer import import_from_string
 
-from samgis_lisa_on_cuda import PROJECT_ROOT_FOLDER
 
+PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
 parser = argparse.ArgumentParser(prog="extract-openapi-fastapi.py")
 parser.add_argument("app", help='App import string. Eg. "main:app"', default="main:app")
 parser.add_argument("--app-dir", help="Directory containing the app", default=None)
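This change drops the script's last import of samgis_lisa_on_cuda: PROJECT_ROOT_FOLDER is now computed locally by walking two levels up from the script file, with "./_" as a fallback when __file__ is undefined (e.g. in a REPL). As an illustration of what the expression resolves to:

    from pathlib import Path

    # e.g. /home/user/repo/scripts/extract-openapi-fastapi.py -> /home/user/repo
    print(Path("/home/user/repo/scripts/extract-openapi-fastapi.py").parent.parent)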
scripts/extract-openapi-lambda.py
DELETED
@@ -1,13 +0,0 @@
-import json
-
-from samgis_lisa_on_cuda import PROJECT_ROOT_FOLDER
-
-if __name__ == '__main__':
-    from samgis_lisa_on_cuda.utilities.type_hints import ApiRequestBody, ApiResponseBodyFailure, ApiResponseBodySuccess
-
-    with open(PROJECT_ROOT_FOLDER / "docs" / "specs" / "openapi_lambda_wip.json", "w") as output_json:
-        json.dump({
-            "ApiRequestBody": ApiRequestBody.model_json_schema(),
-            "ApiResponseBodyFailure": ApiResponseBodyFailure.model_json_schema(),
-            "ApiResponseBodySuccess": ApiResponseBodySuccess.model_json_schema()
-        }, output_json)
tests/__init__.py
CHANGED
@@ -1,6 +1,7 @@
-from samgis_lisa_on_cuda import PROJECT_ROOT_FOLDER
+from pathlib import Path
 
 
+PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
 TEST_ROOT_FOLDER = PROJECT_ROOT_FOLDER / "tests"
 TEST_EVENTS_FOLDER = TEST_ROOT_FOLDER / "events"
 LOCAL_URL_TILE = "http://localhost:8000/lambda_handler/{z}/{x}/{y}.png"
tests/events/lambda_handler/10/550/391.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/550/392.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/550/393.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/551/391.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/551/392.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/551/393.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/552/391.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/552/392.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/552/393.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/553/391.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/553/392.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/553/393.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/554/391.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/554/392.png
CHANGED
Git LFS Details

tests/events/lambda_handler/10/554/393.png
CHANGED
Git LFS Details
tests/io/test_coordinates_pixel_conversion.py
DELETED
@@ -1,27 +0,0 @@
-import json
-
-from samgis_lisa_on_cuda.io.coordinates_pixel_conversion import get_latlng_to_pixel_coordinates
-from samgis_lisa_on_cuda.utilities.type_hints import LatLngDict
-from tests import TEST_EVENTS_FOLDER
-
-
-def test_get_latlng_to_pixel_coordinates():
-    name_fn = "get_latlng_to_pixel_coordinates"
-
-    with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-        inputs_outputs = json.load(tst_json)
-        for k, input_output in inputs_outputs.items():
-            print(f"k:{k}")
-            current_input = input_output["input"]
-            zoom = current_input["zoom"]
-            latlng_origin_ne = LatLngDict.model_validate(current_input["latlng_origin_ne"])
-            latlng_origin_sw = LatLngDict.model_validate(current_input["latlng_origin_sw"])
-            latlng_current_point = LatLngDict.model_validate(current_input["latlng_current_point"])
-            output = get_latlng_to_pixel_coordinates(
-                latlng_origin_ne=latlng_origin_ne,
-                latlng_origin_sw=latlng_origin_sw,
-                latlng_current_point=latlng_current_point,
-                zoom=zoom,
-                k=k
-            )
-            assert output == input_output["output"]
tests/io/test_geo_helpers.py
DELETED
@@ -1,103 +0,0 @@
-import json
-import unittest
-import numpy as np
-import shapely
-
-from samgis_lisa_on_cuda.io.geo_helpers import load_affine_transformation_from_matrix
-from tests import TEST_EVENTS_FOLDER
-
-
-class TestGeoHelpers(unittest.TestCase):
-    def test_load_affine_transformation_from_matrix(self):
-        name_fn = "samexporter_predict"
-
-        expected_output = {
-            'europe': (
-                1524458.6551710723, 0.0, 152.87405657035242, 4713262.318571913, -762229.3275855362, -2356860.470370812
-            ),
-            'north_america': (
-                -13855281.495084189, 0.0, 1222.9924525628194, 6732573.451358326, 6927640.747542094, -3368121.214358007
-            ),
-            'oceania': (
-                7269467.138033403, 0.0, 9783.93962050256, -166326.9735485418, -3634733.5690167015, 68487.57734351706
-            ),
-            'south_america': (
-                -7922544.351904369, 0.0, 305.74811314070394, -5432228.234830927, 3961272.1759521845, 2715655.4952457524
-            )}
-
-        with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-            for k, input_output in inputs_outputs.items():
-                print(f"k:{k}.")
-
-                output = load_affine_transformation_from_matrix(input_output["input"]["matrix"])
-                assert output.to_shapely() == expected_output[k]
-
-    def test_load_affine_transformation_from_matrix_value_error(self):
-        name_fn = "samexporter_predict"
-        with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-            with self.assertRaises(ValueError):
-                try:
-                    io_value_error = inputs_outputs["europe"]["input"]["matrix"][:5]
-                    load_affine_transformation_from_matrix(io_value_error)
-                except ValueError as ve:
-                    print(f"ve:{ve}.")
-                    self.assertEqual(str(ve), "Expected 6 coefficients, found 5; argument type: <class 'list'>.")
-                    raise ve
-
-    def test_load_affine_transformation_from_matrix_exception(self):
-        name_fn = "samexporter_predict"
-        with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-            with self.assertRaises(Exception):
-                try:
-                    io_exception = inputs_outputs["europe"]["input"]["matrix"]
-                    io_exception[0] = "ciao"
-                    load_affine_transformation_from_matrix(io_exception)
-                except Exception as e:
-                    print(f"e:{e}.")
-                    self.assertEqual(str(e), "exception:could not convert string to float: 'ciao', "
-                                             "check https://github.com/rasterio/affine project for updates")
-                    raise e
-
-    def test_get_vectorized_raster_as_geojson_ok(self):
-        from rasterio.transform import Affine
-        from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
-
-        name_fn = "samexporter_predict"
-
-        with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-            for k, input_output in inputs_outputs.items():
-                print(f"k:{k}.")
-                mask = np.load(TEST_EVENTS_FOLDER / name_fn / k / "mask.npy")
-
-                transform = Affine.from_gdal(*input_output["input"]["matrix"])
-                output = get_vectorized_raster_as_geojson(mask=mask, transform=transform)
-                assert output["n_shapes_geojson"] == input_output["output"]["n_shapes_geojson"]
-                output_geojson = shapely.from_geojson(output["geojson"])
-                expected_output_geojson = shapely.from_geojson(input_output["output"]["geojson"])
-                assert shapely.equals_exact(output_geojson, expected_output_geojson, tolerance=0.000006)
-
-    def test_get_vectorized_raster_as_geojson_fail(self):
-        from samgis_lisa_on_cuda.io.geo_helpers import get_vectorized_raster_as_geojson
-
-        name_fn = "samexporter_predict"
-
-        with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-            for k, input_output in inputs_outputs.items():
-                print(f"k:{k}.")
-                mask = np.load(TEST_EVENTS_FOLDER / name_fn / k / "mask.npy")
-
-                # Could be also another generic Exception, here we intercept TypeError caused by wrong matrix input on
-                # rasterio.Affine.from_gdal() wrapped by get_affine_transform_from_gdal()
-                with self.assertRaises(IndexError):
-                    try:
-                        wrong_matrix = 1.0,
-                        get_vectorized_raster_as_geojson(mask=mask, transform=wrong_matrix)
-                    except IndexError as te:
-                        print(f"te:{te}.")
-                        self.assertEqual(str(te), 'tuple index out of range')
-                        raise te
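These tests build transforms with the affine package's GDAL-ordered constructor; a standalone sketch of the round trip (the coefficients here are illustrative):

    from affine import Affine

    # GDAL geotransform order: (top-left x, pixel width, row rotation, top-left y, column rotation, pixel height)
    gt = (1524458.66, 152.874, 0.0, 4713262.32, 0.0, -152.874)
    transform = Affine.from_gdal(*gt)
    assert transform.to_gdal() == gt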
tests/io/test_raster_helpers.py
DELETED
@@ -1,255 +0,0 @@
-import unittest
-from unittest.mock import patch
-import numpy as np
-
-from samgis_core.utilities.utilities import hash_calculate
-from samgis_lisa_on_cuda.io import raster_helpers
-
-
-def get_three_channels(size=5, param1=1000, param2=3, param3=-88):
-    arr_base = np.arange(size*size).reshape(size, size) / size**2
-    channel_0 = arr_base * param1
-    channel_1 = arr_base * param2
-    channel_2 = arr_base * param3
-    return channel_0, channel_1, channel_2
-
-
-def helper_bell(size=10, param1=0.1, param2=2):
-    x = np.linspace(-size, size, num=size**2)
-    y = np.linspace(-size, size, num=size**2)
-    x, y = np.meshgrid(x, y)
-    return np.exp(-param1 * x ** param2 - param1 * y ** param2)
-
-
-arr_5x5x5 = np.arange(125).reshape((5, 5, 5)) / 25
-arr = np.arange(25).resize((5, 5))
-channel0, channel1, channel2 = get_three_channels()
-z = helper_bell()
-slope_z_cellsize3, curvature_z_cellsize3 = raster_helpers.get_slope_curvature(z, slope_cellsize=3)
-
-
-class Test(unittest.TestCase):
-
-    def test_get_rgb_prediction_image_real(self):
-        output = raster_helpers.get_rgb_prediction_image(z, slope_cellsize=61, invert_image=True)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'QpQ9yxgCLw9cf3klNFKNFXIDHaSkuiZxkbpeQApR8pA='
-        output = raster_helpers.get_rgb_prediction_image(z, slope_cellsize=61, invert_image=False)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'Y+iXO9w/sKzNVOw2rBh2JrVGJUFRqaa8/0F9hpevmLs='
-
-    @patch.object(raster_helpers, "get_slope_curvature")
-    @patch.object(raster_helpers, "normalize_array_list")
-    @patch.object(raster_helpers, "get_rgb_image")
-    def test_get_rgb_prediction_image_mocked(self, get_rgb_image_mocked, normalize_array_list, get_slope_curvature):
-        local_arr = np.array(z * 100, dtype=np.uint8)
-
-        get_slope_curvature.return_value = slope_z_cellsize3, curvature_z_cellsize3
-        normalize_array_list.side_effect = None
-        get_rgb_image_mocked.return_value = np.bitwise_not(local_arr)
-        output = raster_helpers.get_rgb_prediction_image(local_arr, slope_cellsize=61, invert_image=True)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'BPIyVH64RgVunj42EuQAx4/v59Va8ZAjcMnuiGNqTT0='
-        get_rgb_image_mocked.return_value = local_arr
-        output = raster_helpers.get_rgb_prediction_image(local_arr, slope_cellsize=61, invert_image=False)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'XX54sdLQQUrhkUHT6ikQZYSloMYDSfh/AGITDq6jnRM='
-
-    @patch.object(raster_helpers, "get_slope_curvature")
-    def test_get_rgb_prediction_image_value_error(self, get_slope_curvature):
-        msg = "this is a value error"
-        get_slope_curvature.side_effect = ValueError(msg)
-
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.get_rgb_prediction_image(arr, slope_cellsize=3)
-            except ValueError as ve:
-                self.assertEqual(str(ve), msg)
-                raise ve
-
-    def test_get_rgb_image(self):
-        output = raster_helpers.get_rgb_image(channel0, channel1, channel2, invert_image=True)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'YVnRWla5Ptfet6reSfM+OEIsGytLkeso6X+CRs34YHk='
-        output = raster_helpers.get_rgb_image(channel0, channel1, channel2, invert_image=False)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'LC/kIZGUZULSrwwSXCeP1My2spTZdW9D7LH+tltwERs='
-
-    def test_get_rgb_image_value_error_1(self):
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.get_rgb_image(arr_5x5x5, arr_5x5x5, arr_5x5x5, invert_image=True)
-            except ValueError as ve:
-                self.assertEqual(f"arr_size, wrong type:{type(arr_5x5x5)} or arr_size:{arr_5x5x5.shape}.", str(ve))
-                raise ve
-
-    def test_get_rgb_image_value_error2(self):
-        arr_0 = np.arange(25).reshape((5, 5))
-        arr_1 = np.arange(4).reshape((2, 2))
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.get_rgb_image(arr_0, arr_1, channel2, invert_image=True)
-            except ValueError as ve:
-                self.assertEqual('could not broadcast input array from shape (2,2) into shape (5,5)', str(ve))
-                raise ve
-
-    def test_get_slope_curvature(self):
-        slope_output, curvature_output = raster_helpers.get_slope_curvature(z, slope_cellsize=3)
-        hash_curvature = hash_calculate(curvature_output)
-        hash_slope = hash_calculate(slope_output)
-        assert hash_curvature == b'LAL9JFOjJP9D6X4X3fVCpnitx9VPM9drS5YMHwMZ3iE='
-        assert hash_slope == b'IYf6x4G0lmR47j6HRS5kUYWdtmimhLz2nak8py75nwc='
-
-    def test_get_slope_curvature_value_error(self):
-        from samgis_lisa_on_cuda.io import raster_helpers
-
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.get_slope_curvature(np.array(1), slope_cellsize=3)
-            except ValueError as ve:
-                self.assertEqual('not enough values to unpack (expected 2, got 0)', str(ve))
-                raise ve
-
-    def test_calculate_slope(self):
-        slope_output = raster_helpers.calculate_slope(z, cell_size=3)
-        hash_output = hash_calculate(slope_output)
-        assert hash_output == b'IYf6x4G0lmR47j6HRS5kUYWdtmimhLz2nak8py75nwc='
-
-    def test_calculate_slope_value_error(self):
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.calculate_slope(np.array(1), cell_size=3)
-            except ValueError as ve:
-                self.assertEqual('not enough values to unpack (expected 2, got 0)', str(ve))
-                raise ve
-
-    def test_normalize_array(self):
-        def check_ndarrays_almost_equal(cls, arr1, arr2, places, check_type="float", check_ndiff=1):
-            count_abs_diff = 0
-            for list00, list01 in zip(arr1.tolist(), arr2.tolist()):
-                for el00, el01 in zip(list00, list01):
-                    ndiff = abs(el00 - el01)
-                    if el00 != el01:
-                        count_abs_diff += 1
-                    if check_type == "float":
-                        cls.assertAlmostEqual(el00, el01, places=places)
-                    cls.assertLess(ndiff, check_ndiff)  # cls.assertTrue(ndiff < check_ndiff)
-            print("count_abs_diff:", count_abs_diff)
-
-        normalized_array = raster_helpers.normalize_array(z)
-        hash_output = hash_calculate(normalized_array)
-        assert hash_output == b'MPkQwiiQa5NxL7LDvCS9V143YUEJT/Qh1aNEKc/Ehvo='
-
-        mult_variable = 3.423
-        test_array_input = np.arange(256).reshape((16, 16))
-        test_array_output = raster_helpers.normalize_array(test_array_input * mult_variable)
-        check_ndarrays_almost_equal(self, test_array_output, test_array_input, places=8)
-
-        test_array_output1 = raster_helpers.normalize_array(test_array_input * mult_variable, high=128, norm_type="int")
-        o = np.arange(256).reshape((16, 16)) / 2
-        expected_array_output1 = o.astype(int)
-        check_ndarrays_almost_equal(
-            self, test_array_output1, expected_array_output1, places=2, check_type="int", check_ndiff=2)
-
-    @patch.object(np, "nanmin")
-    @patch.object(np, "nanmax")
-    def test_normalize_array_floating_point_error_mocked(self, nanmax_mocked, nanmin_mocked):
-        nanmax_mocked.return_value = 100
-        nanmin_mocked.return_value = 100
-
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.normalize_array(
-                    np.arange(25).reshape((5, 5))
-                )
-            except ValueError as ve:
-                self.assertEqual(
-                    "normalize_array:::h_arr_max:100,h_min_arr:100,fe:divide by zero encountered in divide.",
-                    str(ve)
-                )
-                raise ve
-
-    @patch.object(np, "nanmin")
-    @patch.object(np, "nanmax")
-    def test_normalize_array_exception_error_mocked(self, nanmax_mocked, nanmin_mocked):
-        nanmax_mocked.return_value = 100
-        nanmin_mocked.return_value = np.NaN
-
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.normalize_array(
-                    np.arange(25).reshape((5, 5))
-                )
-            except ValueError as ve:
-                self.assertEqual("cannot convert float NaN to integer", str(ve))
-                raise ve
-
-    def test_normalize_array_value_error(self):
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.normalize_array(
-                    np.zeros((5, 5))
-                )
-            except ValueError as ve:
-                self.assertEqual(
-                    "normalize_array::empty array '',h_min_arr:0.0,h_arr_max:0.0,h_diff:0.0, " 'dtype:float64.',
-                    str(ve)
-                )
-                raise ve
-
-    def test_normalize_array_list(self):
-        normalized_array = raster_helpers.normalize_array_list([channel0, channel1, channel2])
-        hash_output = hash_calculate(normalized_array)
-        assert hash_output == b'+6IbhIpyb3vPElTgqqPkQdIR0umf4uFP2c7t5IaBVvI='
-
-        test_norm_list_output2 = raster_helpers.normalize_array_list(
-            [channel0, channel1, channel2], exaggerations_list=[2.0, 3.0, 5.0])
-        hash_variable2 = hash_calculate(test_norm_list_output2)
-        assert hash_variable2 == b'yYCYWCKO3i8NYsWk/wgYOzSRRLSLUprEs7mChJkdL+A='
-
-    def test_normalize_array_list_value_error(self):
-        with self.assertRaises(ValueError):
-            try:
-                raster_helpers.normalize_array_list([])
-            except ValueError as ve:
-                self.assertEqual("input list can't be empty:[].", str(ve))
-                raise ve
-
-    def test_check_empty_array(self):
-        a = np.zeros((10, 10))
-        b = np.ones((10, 10))
-        c = np.ones((10, 10)) * 2
-        d = np.zeros((10, 10))
-        d[1, 1] = np.nan
-        e = np.ones((10, 10)) * 3
-        e[1, 1] = np.nan
-
-        self.assertTrue(raster_helpers.check_empty_array(a, 999))
-        self.assertTrue(raster_helpers.check_empty_array(b, 0))
-        self.assertTrue(raster_helpers.check_empty_array(c, 2))
-        self.assertTrue(raster_helpers.check_empty_array(d, 0))
-        self.assertTrue(raster_helpers.check_empty_array(e, 3))
-        self.assertFalse(raster_helpers.check_empty_array(z, 3))
-
-    def test_get_nextzen_terrain_rgb_formula(self):
-        output = raster_helpers.get_nextzen_terrain_rgb_formula(channel0, channel1, channel2)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'3KJ81YKmQRdccRZARbByfwo1iMVLj8xxz9mfsWki/qA='
-
-    def test_get_mapbox__terrain_rgb_formula(self):
-        output = raster_helpers.get_mapbox__terrain_rgb_formula(channel0, channel1, channel2)
-        hash_output = hash_calculate(output)
-        assert hash_output == b'RU7CcoKoR3Fkh5LE+m48DHRVUy/vGq6UgfOFUMXx07M='
-
-    def test_get_raster_terrain_rgb_like(self):
-        from samgis_lisa_on_cuda.utilities.type_hints import XYZTerrainProvidersNames
-
-        arr_input = raster_helpers.get_rgb_image(channel0, channel1, channel2, invert_image=True)
-        output_nextzen = raster_helpers.get_raster_terrain_rgb_like(
-            arr_input, XYZTerrainProvidersNames.NEXTZEN_TERRAIN_TILES_NAME)
-        hash_nextzen = hash_calculate(output_nextzen)
-        assert hash_nextzen == b'+o2OTJliJkkBoqiAIGnhJ4s0xoLQ4MxHOvevLhNxysE='
-        output_mapbox = raster_helpers.get_raster_terrain_rgb_like(
-            arr_input, XYZTerrainProvidersNames.MAPBOX_TERRAIN_TILES_NAME)
-        hash_mapbox = hash_calculate(output_mapbox)
-        assert hash_mapbox == b'zWmekyKrpnmHnuDACnveCJl+o4GuhtHJmGlRDVwsce4='
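normalize_array is exercised above only through output hashes; for readers skimming the diff, a minimal standalone sketch of the min-max rescaling being tested (my own illustration under that assumption, not the deleted implementation):

    import numpy as np

    def minmax_normalize(arr: np.ndarray, high: float = 255.0) -> np.ndarray:
        # rescale values into [0, high]; a constant array would divide by zero, hence the guard
        lo, hi = np.nanmin(arr), np.nanmax(arr)
        if hi == lo:
            raise ValueError(f"empty or constant array: min {lo} == max {hi}")
        return (arr - lo) / (hi - lo) * high

    print(minmax_normalize(np.arange(25).reshape(5, 5) * 3.423)[-1, -1])  # 255.0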
tests/io/test_tms2geotiff.py
DELETED
@@ -1,138 +0,0 @@
-import unittest
-
-import numpy as np
-from samgis_core.utilities.utilities import hash_calculate
-
-from samgis_lisa_on_cuda import app_logger
-from samgis_lisa_on_cuda.io.tms2geotiff import download_extent
-from tests import LOCAL_URL_TILE, TEST_EVENTS_FOLDER
-
-
-input_bbox = [[39.036252959636606, 15.040283203125002], [38.302869955150044, 13.634033203125002]]
-
-
-class TestTms2geotiff(unittest.TestCase):
-    # def test_download_extent_simple_source(self):
-    #     from rasterio import Affine
-    #     from xyzservices import TileProvider
-    #     from tests.local_tiles_http_server import LocalTilesHttpServer
-    #
-    #     listen_port = 8000
-    #
-    #     with LocalTilesHttpServer.http_server("localhost", listen_port, directory=TEST_EVENTS_FOLDER):
-    #         pt0, pt1 = input_bbox
-    #         zoom = 10
-    #
-    #         n_lat = pt0[0]
-    #         e_lng = pt0[1]
-    #         s_lat = pt1[0]
-    #         w_lng = pt1[1]
-    #
-    #         source = TileProvider(name="local_tile_provider", url=LOCAL_URL_TILE, attribution="")
-    #         img, matrix = download_extent(w=w_lng, s=s_lat, e=e_lng, n=n_lat, zoom=zoom, source=source)
-    #         app_logger.info(f"# DOWNLOAD ENDED, shape: {img.shape} #")
-    #         np_img = np.ascontiguousarray(img)
-    #         output_hash = hash_calculate(np_img)
-    #         assert output_hash == b'UmbkwbPJpRT1XXcLnLUapUDP320w7YhS/AmT3H7u+b4='
-    #         assert Affine.to_gdal(matrix) == (
-    #             1517657.1966021745, 152.8740565703525, 0.0, 4726942.266183584, 0.0, -152.87405657034955)
-
-    def test_download_extent_source_with_parameter(self):
-        from rasterio import Affine
-        from xyzservices import TileProvider
-        from tests.local_tiles_http_server import LocalTilesHttpServer
-
-        listen_port = 8000
-
-        with LocalTilesHttpServer.http_server("localhost", listen_port, directory=TEST_EVENTS_FOLDER):
-            pt0, pt1 = input_bbox
-            zoom = 10
-
-            n_lat = pt0[0]
-            e_lng = pt0[1]
-            s_lat = pt1[0]
-            w_lng = pt1[1]
-
-            local_url = "http://localhost:8000/{parameter}/{z}/{x}/{y}.png"
-            download_extent_args_no_parameter = {"name": "local_tile_provider", "url": LOCAL_URL_TILE, "attribution": ""}
-            download_extent_args = {
-                "no_parameter": download_extent_args_no_parameter,
-                "with_parameter": {"url": local_url, "parameter": "lambda_handler", **download_extent_args_no_parameter}
-            }
-            for _args_names, _args in download_extent_args.items():
-                app_logger.info(f"args_names:{_args_names}.")
-                source = TileProvider(**_args)
-                img, matrix = download_extent(w=w_lng, s=s_lat, e=e_lng, n=n_lat, zoom=zoom, source=source)
-                app_logger.info(f"# DOWNLOAD ENDED, shape: {img.shape} #")
-                np_img = np.ascontiguousarray(img)
-                output_hash = hash_calculate(np_img)
-                assert output_hash == b'UmbkwbPJpRT1XXcLnLUapUDP320w7YhS/AmT3H7u+b4='
-                assert Affine.to_gdal(matrix) == (
-                    1517657.1966021745, 152.8740565703525, 0.0, 4726942.266183584, 0.0, -152.87405657034955)
-
-    def test_download_extent_source_with_parameter_key_error(self):
-        from xyzservices import TileProvider
-
-        with self.assertRaises(KeyError):
-            try:
-                pt0, pt1 = input_bbox
-                zoom = 10
-
-                n_lat = pt0[0]
-                e_lng = pt0[1]
-                s_lat = pt1[0]
-                w_lng = pt1[1]
-
-                local_url_tile2 = "http://localhost:8000/{parameter}/{z}/{x}/{y}.png"
-                source = TileProvider(name="local_tile_provider", url=local_url_tile2, attribution="")
-                download_extent(w=w_lng, s=s_lat, e=e_lng, n=n_lat, zoom=zoom, source=source)
-            except KeyError as ke:
-                assert str(ke) == "'parameter'"
-                raise ke
-
-    def test_download_extent_io_error1(self):
-
-        with self.assertRaises(Exception):
-            try:
-                pt0, pt1 = input_bbox
-                zoom = 10
-
-                n_lat = pt0[0]
-                e_lng = pt0[1]
-                s_lat = pt1[0]
-                w_lng = pt1[1]
-
-                download_extent(w=w_lng, s=s_lat, e=e_lng, n=n_lat, zoom=zoom, source=f"http://{LOCAL_URL_TILE}")
-                print("exception not raised")
-            except ConnectionError as ioe1:
-                app_logger.error(f"ioe1:{ioe1}.")
-                msg0 = "HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /lambda_handler"
-                msg1 = "Caused by NewConnectionError"
-                msg2 = ": Failed to establish a new connection: [Errno 61] Connection refused'))"
-                assert msg0 in str(ioe1)
-                assert msg1 in str(ioe1)
-                assert msg2 in str(ioe1)
-                raise ioe1
-
-    def test_download_extent_io_error2(self):
-        from requests import HTTPError
-        from tests.local_tiles_http_server import LocalTilesHttpServer
-
-        listen_port = 8000
-        with LocalTilesHttpServer.http_server("localhost", listen_port, directory=TEST_EVENTS_FOLDER):
-            pt0, pt1 = input_bbox
-            zoom = 10
-
-            with self.assertRaises(HTTPError):
-                try:
-                    n_lat = pt0[0]
-                    e_lng = pt0[1]
-                    s_lat = pt1[0]
-                    w_lng = pt1[1]
-
-                    download_extent(w=w_lng, s=s_lat, e=e_lng, n=n_lat, zoom=zoom,
-                                    source=LOCAL_URL_TILE + "_not_found_raster!")
-                except HTTPError as http_e:
-                    app_logger.error(f"ae:{http_e}.")
-                    assert "Tile URL resulted in a 404 error. Double-check your tile url:" in str(http_e)
-                    raise http_e
tests/io/test_wrappers_helpers.py
DELETED
@@ -1,135 +0,0 @@
-import json
-import time
-import unittest
-
-from http import HTTPStatus
-from unittest.mock import patch
-
-from samgis_lisa_on_cuda.io import wrappers_helpers
-from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt, get_parsed_request_body, get_response
-from samgis_lisa_on_cuda.utilities.type_hints import ApiRequestBody
-from tests import TEST_EVENTS_FOLDER
-
-
-class WrappersHelpersTest(unittest.TestCase):
-    @patch.object(time, "time")
-    def test_get_response(self, time_mocked):
-        time_diff = 108
-        end_run = 1000
-        time_mocked.return_value = end_run
-        start_time = end_run - time_diff
-        aws_request_id = "test_invoke_id"
-
-        with open(TEST_EVENTS_FOLDER / "get_response.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-
-        response_type = "200"
-        body_response = inputs_outputs[response_type]["input"]
-        output = get_response(HTTPStatus.OK.value, start_time, aws_request_id, body_response)
-        assert json.loads(output) == inputs_outputs[response_type]["output"]
-
-        response_type = "400"
-        response_400 = get_response(HTTPStatus.BAD_REQUEST.value, start_time, aws_request_id, {})
-        assert response_400 == inputs_outputs[response_type]["output"]
-
-        response_type = "422"
-        response_422 = get_response(HTTPStatus.UNPROCESSABLE_ENTITY.value, start_time, aws_request_id, {})
-        assert response_422 == inputs_outputs[response_type]["output"]
-
-        response_type = "500"
-        response_500 = get_response(HTTPStatus.INTERNAL_SERVER_ERROR.value, start_time, aws_request_id, {})
-        assert response_500 == inputs_outputs[response_type]["output"]
-
-    @staticmethod
-    def test_get_parsed_bbox_points():
-        with open(TEST_EVENTS_FOLDER / "get_parsed_bbox_prompts_single_point.json") as tst_json:
-            inputs_outputs = json.load(tst_json)
-            for k, input_output in inputs_outputs.items():
-                print(f"k:{k}.")
-                raw_body = get_parsed_request_body(**input_output["input"])
-                output = get_parsed_bbox_points_with_dictlist_prompt(raw_body)
-                assert output == input_output["output"]
-
-    @staticmethod
-    def test_get_parsed_bbox_other_inputs():
-        for json_filename in ["single_rectangle", "multi_prompt"]:
-            with open(TEST_EVENTS_FOLDER / f"get_parsed_bbox_prompts_{json_filename}.json") as tst_json:
-                inputs_outputs = json.load(tst_json)
-                parsed_input = ApiRequestBody.model_validate(inputs_outputs["input"])
-                output = get_parsed_bbox_points_with_dictlist_prompt(parsed_input)
-                assert output == inputs_outputs["output"]
-
-    @staticmethod
-    def test_get_parsed_request_body():
-        from samgis_core.utilities.utilities import base64_encode
-
-        input_event = {
-            "event": {
-                "bbox": {
-                    "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
-                    "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
-                },
-                "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
-                "zoom": 10, "source_type": "OpenStreetMap.Mapnik", "debug": True
-            }
-        }
-        expected_output_dict = {
-            "bbox": {
-                "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
-                "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
-            },
-            "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
-            "zoom": 10, "source_type": "OpenStreetMap.Mapnik", "debug": True
-        }
-        output = get_parsed_request_body(input_event["event"])
-        assert output == ApiRequestBody.model_validate(input_event["event"])
-
-        input_event_str = json.dumps(input_event["event"])
-        output = get_parsed_request_body(input_event_str)
-        assert output == ApiRequestBody.model_validate(expected_output_dict)
-
-        event = {"body": base64_encode(input_event_str).decode("utf-8")}
-        output = get_parsed_request_body(event)
-        assert output == ApiRequestBody.model_validate(expected_output_dict)
-
-    @patch.object(wrappers_helpers, "providers")
-    def test_get_url_tile(self, providers_mocked):
-        import xyzservices
-        from samgis_lisa_on_cuda.io.wrappers_helpers import get_url_tile
-
-        from tests import LOCAL_URL_TILE
-
-        local_tile_provider = xyzservices.TileProvider(name="local_tile_provider", url=LOCAL_URL_TILE, attribution="")
-        expected_output = {'name': 'local_tile_provider', 'url': LOCAL_URL_TILE, 'attribution': ''}
-        providers_mocked.query_name.return_value = local_tile_provider
-        assert get_url_tile("OpenStreetMap") == expected_output
-
-        local_url = 'http://localhost:8000/{parameter}/{z}/{x}/{y}.png'
-        local_tile_provider = xyzservices.TileProvider(
-            name="local_tile_provider_param", url=local_url, attribution="", parameter="lamda_handler"
-        )
-        providers_mocked.query_name.return_value = local_tile_provider
-        assert get_url_tile("OpenStreetMap.HOT") == {
-            "parameter": "lamda_handler", 'name': 'local_tile_provider_param', 'url': local_url, 'attribution': ''
-        }
-
-    @staticmethod
-    def test_get_url_tile_real():
-        from samgis_lisa_on_cuda.io.wrappers_helpers import get_url_tile
-
-        assert get_url_tile("OpenStreetMap") == {
-            'url': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19,
-            'html_attribution': '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors',
-            'attribution': '(C) OpenStreetMap contributors',
-            'name': 'OpenStreetMap.Mapnik'}
-
-        html_attribution_hot = '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors, '
-        html_attribution_hot += 'Tiles style by <a href="https://www.hotosm.org/" target="_blank">Humanitarian '
-        html_attribution_hot += 'OpenStreetMap Team</a> hosted by <a href="https://openstreetmap.fr/" target="_blank">'
-        html_attribution_hot += 'OpenStreetMap France</a>'
-        attribution_hot = '(C) OpenStreetMap contributors, Tiles style by Humanitarian OpenStreetMap Team hosted by '
-        attribution_hot += 'OpenStreetMap France'
-        assert get_url_tile("OpenStreetMap.HOT") == {
-            'url': 'https://{s}.tile.openstreetmap.fr/hot/{z}/{x}/{y}.png', 'max_zoom': 19,
-            'html_attribution': html_attribution_hot, 'attribution': attribution_hot, 'name': 'OpenStreetMap.HOT'
-        }
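get_url_tile wraps the xyzservices provider lookup that the mocked test replaces; as a standalone illustration, the underlying call on the real library looks like this:

    import xyzservices

    provider = xyzservices.providers.query_name("OpenStreetMap.Mapnik")
    print(provider["url"])  # https://tile.openstreetmap.org/{z}/{x}/{y}.png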
tests/local_tiles_http_server.py
DELETED
@@ -1,46 +0,0 @@
-import logging
-import time
-import unittest
-
-
-class LocalTilesHttpServer(unittest.TestCase):
-    from contextlib import contextmanager
-
-    @staticmethod
-    @contextmanager
-    def http_server(host: str, port: int, directory: str):
-        """Function http_server defined within this test class to avoid pytest error "fixture 'host' not found"."""
-        from functools import partial
-        from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
-        from threading import Thread
-
-        server = ThreadingHTTPServer(
-            (host, port), partial(SimpleHTTPRequestHandler, directory=directory)
-        )
-        print("dir:", directory, "#")
-        server_thread = Thread(target=server.serve_forever, name="http_server")
-        server_thread.start()
-        logging.info(f"listen:: host {host}, port {port}.")
-
-        try:
-            yield
-        finally:
-            server.shutdown()
-            server_thread.join()
-
-
-if __name__ == '__main__':
-    # from tests import TEST_ROOT_FOLDER
-    from pathlib import Path
-
-    PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
-
-    TEST_ROOT_FOLDER = PROJECT_ROOT_FOLDER / "tests"
-    TEST_EVENTS_FOLDER = TEST_ROOT_FOLDER / "events"
-
-    main_listen_port = 8000
-    logging.info(f"http_basedir_serve: {TEST_ROOT_FOLDER}.")
-    with LocalTilesHttpServer.http_server("localhost", main_listen_port, directory=str(TEST_ROOT_FOLDER)):
-        time.sleep(1000)
-        logging.info("""import time; time.sleep(10)""")
-    # logging.info("Http server stopped.")
tests/prediction_api/__init__.py
DELETED
File without changes
tests/prediction_api/test_predictors.py
DELETED
@@ -1,64 +0,0 @@
-import json
-from unittest.mock import patch
-
-import numpy as np
-
-from samgis_lisa_on_cuda.prediction_api import predictors
-from samgis_lisa_on_cuda.prediction_api.predictors import get_raster_inference, samexporter_predict
-from tests import TEST_EVENTS_FOLDER
-
-
-@patch.object(predictors, "SegmentAnythingONNX")
-def test_get_raster_inference(segment_anything_onnx_mocked):
-    name_fn = "samexporter_predict"
-
-    with open(TEST_EVENTS_FOLDER / f"{name_fn}.json") as tst_json:
-        inputs_outputs = json.load(tst_json)
-        for k, input_output in inputs_outputs.items():
-            model_mocked = segment_anything_onnx_mocked()
-
-            img = np.load(TEST_EVENTS_FOLDER / f"{name_fn}" / k / "img.npy")
-            inference_out = np.load(TEST_EVENTS_FOLDER / f"{name_fn}" / k / "inference_out.npy")
-            mask = np.load(TEST_EVENTS_FOLDER / f"{name_fn}" / k / "mask.npy")
-            prompt = input_output["input"]["prompt"]
-            model_name = input_output["input"]["model_name"]
-
-            model_mocked.embed.return_value = np.array(img)
-            model_mocked.embed.side_effect = None
-            model_mocked.predict_masks.return_value = inference_out
-            model_mocked.predict_masks.side_effect = None
-            print(f"k:{k}.")
-            output_mask, len_inference_out = get_raster_inference(
-                img=img,
-                prompt=prompt,
-                models_instance=model_mocked,
-                model_name=model_name
-            )
-            assert np.array_equal(output_mask, mask)
-            assert len_inference_out == input_output["output"]["n_predictions"]
-
-
-@patch.object(predictors, "get_raster_inference")
-@patch.object(predictors, "SegmentAnythingONNX")
-@patch.object(predictors, "download_extent")
-@patch.object(predictors, "get_vectorized_raster_as_geojson")
-def test_samexporter_predict(
-        get_vectorized_raster_as_geojson_mocked,
-        download_extent_mocked,
-        segment_anything_onnx_mocked,
-        get_raster_inference_mocked
-):
-    """
-    model_instance = SegmentAnythingONNX()
-    img, matrix = download_extent(DEFAULT_TMS, pt0[0], pt0[1], pt1[0], pt1[1], zoom)
-    transform = get_affine_transform_from_gdal(matrix)
-    mask, n_predictions = get_raster_inference(img, prompt, models_instance, model_name)
-    get_vectorized_raster_as_geojson(mask, matrix)
-    """
-    aff = 1, 2, 3, 4, 5, 6
-    segment_anything_onnx_mocked.return_value = "SegmentAnythingONNX_instance"
-    download_extent_mocked.return_value = np.zeros((10, 10)), aff
-    get_raster_inference_mocked.return_value = np.ones((10, 10)), 1
-    get_vectorized_raster_as_geojson_mocked.return_value = {"geojson": "{}", "n_shapes_geojson": 2}
-    output = samexporter_predict(bbox=[[1, 2], [3, 4]], prompt=[{}], zoom=10, model_name_key="fastsam")
-    assert output == {"n_predictions": 1, "geojson": "{}", "n_shapes_geojson": 2}
tests/{test_fastapi_app.py → test_app.py}
RENAMED
@@ -4,19 +4,15 @@ import unittest
 from unittest.mock import patch
 
 from fastapi.testclient import TestClient
-
-
-from samgis_lisa_on_cuda.io import wrappers_helpers
+from samgis_web.web import web_helpers
+import app
 from tests import TEST_EVENTS_FOLDER
-from tests.local_tiles_http_server import LocalTilesHttpServer
-from wrappers import fastapi_wrapper
-from wrappers.fastapi_wrapper import app
 
 
 infer_samgis = "/infer_samgis"
 response_status_code = "response.status_code:{}."
 response_body_loaded = "response.body_loaded:{}."
-client = TestClient(app)
+client = TestClient(app.app)
 source = {
     'url': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19,
     'html_attribution': '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors',
@@ -81,18 +77,18 @@ class TestFastapiApp(unittest.TestCase):
         body = response.json()
         assert body == {'msg': 'Error - Unprocessable Entity'}
 
-    def test_index(self):
-        import subprocess
-
-        subprocess.run(["pnpm", "build"], cwd=PROJECT_ROOT_FOLDER / "static")
-        subprocess.run(["pnpm", "tailwindcss", "-i", "./src/input.css", "-o", "./dist/output.css"],
-                       cwd=PROJECT_ROOT_FOLDER / "static")
-        response = client.get("/")
-        assert response.status_code == 200
-        html_body = response.read().decode("utf-8")
-        assert "html" in html_body
-        assert "head" in html_body
-        assert "body" in html_body
+    # def test_index(self):
+    #     import subprocess
+    #
+    #     subprocess.run(["pnpm", "build"], cwd=PROJECT_ROOT_FOLDER / "static")
+    #     subprocess.run(["pnpm", "tailwindcss", "-i", "./src/input.css", "-o", "./dist/output.css"],
+    #                    cwd=PROJECT_ROOT_FOLDER / "static")
+    #     response = client.get("/")
+    #     assert response.status_code == 200
+    #     html_body = response.read().decode("utf-8")
+    #     assert "html" in html_body
+    #     assert "head" in html_body
+    #     assert "body" in html_body
 
     def test_404(self):
         response = client.get("/404")
@@ -119,7 +115,7 @@ class TestFastapiApp(unittest.TestCase):
         assert body_loaded == {'success': False}
 
     @patch.object(time, "time")
-    @patch.object(fastapi_wrapper, "samexporter_predict")
+    @patch.object(app, "samexporter_predict")
     def test_infer_samgis_500(self, samexporter_predict_mocked, time_mocked):
         time_mocked.return_value = 0
         samexporter_predict_mocked.side_effect = ValueError("I raise a value error!")
@@ -131,11 +127,12 @@ class TestFastapiApp(unittest.TestCase):
         print(response_body_loaded.format(body))
         assert body == {'msg': 'Error - Internal Server Error'}
 
-    @patch.object(wrappers_helpers, "get_url_tile")
+    @patch.object(web_helpers, "get_url_tile")
     @patch.object(time, "time")
     def test_infer_samgis_real_200(self, time_mocked, get_url_tile_mocked):
         import shapely
         import xyzservices
+        from samgis_web.utilities.local_tiles_http_server import LocalTilesHttpServer
         from tests import LOCAL_URL_TILE, TEST_EVENTS_FOLDER
 
         time_mocked.return_value = 0
@@ -162,7 +159,7 @@ class TestFastapiApp(unittest.TestCase):
         assert len(output_geojson.geoms) == 3
 
     @patch.object(time, "time")
-    @patch.object(fastapi_wrapper, "samexporter_predict")
+    @patch.object(app, "samexporter_predict")
     def test_infer_samgis_mocked_200(self, samexporter_predict_mocked, time_mocked):
         self.maxDiff = None
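The patch targets change with this rename because `unittest.mock` must patch the name where it is looked up, not where it is defined: the new Gradio-SDK entry point `app.py` binds `samexporter_predict` into its own namespace, so the tests patch `app.samexporter_predict`. A minimal sketch, assuming `app.py` imports the function directly (the exact import line is an assumption, not shown in this diff):

```python
# Assumption: app.py contains something like
#     from samgis_lisa.prediction_api.predictors import samexporter_predict
from unittest.mock import patch
import app

with patch.object(app, "samexporter_predict") as predict_mocked:
    predict_mocked.return_value = {"n_predictions": 1, "geojson": "{}", "n_shapes_geojson": 2}
    # any endpoint that calls app.samexporter_predict now receives the canned result
```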
tests/test_lambda_app.py
DELETED
@@ -1,232 +0,0 @@
-import json
-import time
-import unittest
-from unittest.mock import patch
-
-from samgis_lisa_on_cuda import IS_AWS_LAMBDA
-
-if IS_AWS_LAMBDA:
-    try:
-        from awslambdaric.lambda_context import LambdaContext
-
-        from samgis_lisa_on_cuda.io import wrappers_helpers
-        from wrappers import lambda_wrapper
-        from tests.local_tiles_http_server import LocalTilesHttpServer
-
-
-        class TestLambdaApp(unittest.TestCase):
-            @patch.object(time, "time")
-            @patch.object(lambda_wrapper, "samexporter_predict")
-            @patch.object(lambda_wrapper, "get_parsed_bbox_points")
-            @patch.object(lambda_wrapper, "get_parsed_request_body")
-            def test_lambda_handler_500(
-                    self,
-                    get_parsed_request_body_mocked,
-                    get_parsed_bbox_points_mocked,
-                    samexporter_predict_mocked,
-                    time_mocked
-            ):
-                from wrappers.lambda_wrapper import lambda_handler
-
-                time_mocked.return_value = 0
-                get_parsed_request_body_mocked.value = {}
-                get_parsed_bbox_points_mocked.return_value = {"bbox": "bbox_object", "prompt": "prompt_object",
-                                                              "zoom": 1}
-                samexporter_predict_mocked.side_effect = ValueError("I raise a value error!")
-
-                event = {"body": {}, "version": 1.0}
-                lambda_context = LambdaContext(
-                    invoke_id="test_invoke_id",
-                    client_context=None,
-                    cognito_identity=None,
-                    epoch_deadline_time_in_ms=time.time()
-                )
-                expected_response_500 = '{"statusCode": 500, "header": {"Content-Type": "application/json"}, '
-                expected_response_500 += '"body": "{\\"duration_run\\": 0, \\"message\\": \\"Internal server error\\", '
-                expected_response_500 += '\\"request_id\\": \\"test_invoke_id\\"}", "isBase64Encoded": false}'
-
-                assert lambda_handler(event, lambda_context) == expected_response_500
-
-            @patch.object(time, "time")
-            @patch.object(lambda_wrapper, "get_parsed_request_body")
-            def test_lambda_handler_400(self, get_parsed_request_body_mocked, time_mocked):
-                from wrappers.lambda_wrapper import lambda_handler
-
-                time_mocked.return_value = 0
-                get_parsed_request_body_mocked.return_value = {}
-
-                event = {"body": {}, "version": 1.0}
-                lambda_context = LambdaContext(
-                    invoke_id="test_invoke_id",
-                    client_context=None,
-                    cognito_identity=None,
-                    epoch_deadline_time_in_ms=time.time()
-                )
-
-                assert lambda_handler(event, lambda_context) == (
-                    '{"statusCode": 400, "header": {"Content-Type": "application/json"}, '
-                    '"body": "{\\"duration_run\\": 0, \\"message\\": \\"Bad Request\\", '
-                    '\\"request_id\\": \\"test_invoke_id\\"}", "isBase64Encoded": false}')
-
-            @patch.object(time, "time")
-            def test_lambda_handler_422(self, time_mocked):
-                from wrappers.lambda_wrapper import lambda_handler
-
-                time_mocked.return_value = 0
-                event = {"body": {}, "version": 1.0}
-                lambda_context = LambdaContext(
-                    invoke_id="test_invoke_id",
-                    client_context=None,
-                    cognito_identity=None,
-                    epoch_deadline_time_in_ms=time.time()
-                )
-
-                response_422 = lambda_handler(event, lambda_context)
-                expected_response_422 = '{"statusCode": 422, "header": {"Content-Type": "application/json"}, '
-                expected_response_422 += '"body": "{\\"duration_run\\": 0, \\"message\\": \\"Missing required parameter\\", '
-                expected_response_422 += '\\"request_id\\": \\"test_invoke_id\\"}", "isBase64Encoded": false}'
-
-                assert response_422 == expected_response_422
-
-            @patch.object(time, "time")
-            @patch.object(lambda_wrapper, "samexporter_predict")
-            @patch.object(lambda_wrapper, "get_response")
-            @patch.object(lambda_wrapper, "get_parsed_bbox_points")
-            @patch.object(lambda_wrapper, "get_parsed_request_body")
-            def test_lambda_handler_200_mocked(
-                    self,
-                    get_parsed_request_body_mocked,
-                    get_parsed_bbox_points_mocked,
-                    get_response_mocked,
-                    samexporter_predict_mocked,
-                    time_mocked
-            ):
-                from wrappers.lambda_wrapper import lambda_handler
-                from tests import TEST_EVENTS_FOLDER
-
-                time_mocked.return_value = 0
-                get_parsed_request_body_mocked.value = {}
-                get_parsed_bbox_points_mocked.return_value = {"bbox": "bbox_object", "prompt": "prompt_object", "zoom": 1}
-
-                response_type = "200"
-                with open(TEST_EVENTS_FOLDER / "get_response.json") as tst_json_get_response:
-                    get_response_io = json.load(tst_json_get_response)
-
-                input_200 = {
-                    "bbox": {
-                        "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
-                        "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068}
-                    },
-                    "prompt": [{
-                        "type": "point",
-                        "data": {"lat": 37.0, "lng": 15.0},
-                        "label": 0
-                    }],
-                    "zoom": 10,
-                    "source_type": "OpenStreetMap.Mapnik",
-                    "debug": True
-                }
-
-                samexporter_predict_output = get_response_io[response_type]["input"]
-                samexporter_predict_mocked.return_value = samexporter_predict_output
-                samexporter_predict_mocked.side_effect = None
-                get_response_mocked.return_value = get_response_io[response_type]["output"]
-
-                event = {"body": input_200, "version": 1.0}
-
-                lambda_context = LambdaContext(
-                    invoke_id="test_invoke_id",
-                    client_context=None,
-                    cognito_identity=None,
-                    epoch_deadline_time_in_ms=time.time()
-                )
-
-                response_200 = lambda_handler(event, lambda_context)
-                expected_response_200 = get_response_io[response_type]["output"]
-                print(f"types: response_200:{type(response_200)}, expected:{type(expected_response_200)}.")
-                assert response_200 == expected_response_200
-
-            @patch.object(wrappers_helpers, "get_url_tile")
-            def test_lambda_handler_200_real_single_multi_point(self, get_url_tile_mocked):
-                import xyzservices
-                import shapely
-
-                from wrappers.lambda_wrapper import lambda_handler
-                from tests import LOCAL_URL_TILE, TEST_EVENTS_FOLDER
-
-                local_tile_provider = xyzservices.TileProvider(name="local_tile_provider", url=LOCAL_URL_TILE,
-                                                               attribution="")
-                get_url_tile_mocked.return_value = local_tile_provider
-                fn_name = "lambda_handler"
-                invoke_id = "test_invoke_id"
-
-                for json_filename in [
-                    "single_point",
-                    "multi_prompt",
-                    "single_rectangle"
-                ]:
-                    with open(TEST_EVENTS_FOLDER / f"{fn_name}_{json_filename}.json") as tst_json:
-                        inputs_outputs = json.load(tst_json)
-                        lambda_context = LambdaContext(
-                            invoke_id=invoke_id,
-                            client_context=None,
-                            cognito_identity=None,
-                            epoch_deadline_time_in_ms=time.time()
-                        )
-                        expected_response_dict = inputs_outputs["output"]
-                        listen_port = 8000
-                        expected_response_body = json.loads(expected_response_dict["body"])
-
-                        with LocalTilesHttpServer.http_server("localhost", listen_port, directory=TEST_EVENTS_FOLDER):
-                            input_event = inputs_outputs["input"]
-                            input_event_body = json.loads(input_event["body"])
-                            input_event["body"] = json.dumps(input_event_body)
-                            response = lambda_handler(event=input_event, context=lambda_context)
-
-                        response_dict = json.loads(response)
-                        assert response_dict["statusCode"] == 200
-                        body_dict = json.loads(response_dict["body"])
-                        assert body_dict["n_predictions"] == 1
-                        assert body_dict["request_id"] == invoke_id
-                        assert body_dict["message"] == "ok"
-                        assert body_dict["n_shapes_geojson"] == expected_response_body["n_shapes_geojson"]
-
-                        output_geojson = shapely.from_geojson(body_dict["geojson"])
-                        print("output_geojson::", type(output_geojson))
-                        assert isinstance(output_geojson, shapely.GeometryCollection)
-                        assert len(output_geojson.geoms) == expected_response_body["n_shapes_geojson"]
-
-            def test_debug(self):
-                from wrappers.lambda_wrapper import lambda_handler
-
-                input_event = {
-                    'bbox': {
-                        'ne': {'lat': 46.302592089330524, 'lng': 9.49493408203125},
-                        'sw': {'lat': 46.14011755129237, 'lng': 9.143371582031252}},
-                    'prompt': [
-                        {'id': 166, 'type': 'point', 'data': {'lat': 46.18244521829928, 'lng': 9.418544769287111},
-                         'label': 1}
-                    ],
-                    'zoom': 12, 'source_type': 'OpenStreetMap'
-                }
-                lambda_context = LambdaContext(
-                    invoke_id="test_invoke_id",
-                    client_context=None,
-                    cognito_identity=None,
-                    epoch_deadline_time_in_ms=time.time()
-                )
-                response = lambda_handler(event=input_event, context=lambda_context)
-                print(response)
-    except ModuleNotFoundError as mnfe:
-        print("missing awslambdaric...")
-        raise mnfe
-
-
-if __name__ == '__main__':
-    if IS_AWS_LAMBDA:
-        unittest.main()
wrappers/__init__.py
DELETED
File without changes
wrappers/fastapi_wrapper.py
DELETED
@@ -1,273 +0,0 @@
-import json
-import os
-import pathlib
-import uuid
-
-from fastapi.templating import Jinja2Templates
-import uvicorn
-from fastapi import FastAPI, HTTPException, Request, status
-from fastapi.exceptions import RequestValidationError
-from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
-from fastapi.staticfiles import StaticFiles
-from pydantic import ValidationError
-
-from samgis_lisa_on_cuda import PROJECT_ROOT_FOLDER, WORKDIR
-from samgis_lisa_on_cuda.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody
-from samgis_core.utilities.fastapi_logger import setup_logging
-
-
-app_logger = setup_logging(debug=True)
-app = FastAPI()
-
-
-@app.middleware("http")
-async def request_middleware(request, call_next):
-    request_id = str(uuid.uuid4())
-    with app_logger.contextualize(request_id=request_id):
-        app_logger.info("Request started")
-
-        try:
-            response = await call_next(request)
-
-        except Exception as ex:
-            app_logger.error(f"Request failed: {ex}")
-            response = JSONResponse(content={"success": False}, status_code=500)
-
-        finally:
-            response.headers["X-Request-ID"] = request_id
-            app_logger.info("Request ended")
-
-        return response
-
-
-@app.post("/post_test_dictlist")
-def post_test_dictlist2(request_input: ApiRequestBody) -> JSONResponse:
-    from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt
-
-    request_body = get_parsed_bbox_points_with_dictlist_prompt(request_input)
-    app_logger.info(f"request_body:{request_body}.")
-    return JSONResponse(
-        status_code=200,
-        content=request_body
-    )
-
-
-@app.get("/health")
-async def health() -> JSONResponse:
-    import importlib.metadata
-    from importlib.metadata import PackageNotFoundError
-
-    try:
-        core_version = importlib.metadata.version('samgis_core')
-        lisa_on_cuda_version = importlib.metadata.version('lisa-on-cuda')
-        samgis_lisa_on_cuda_version = importlib.metadata.version('samgis-lisa-on-cuda')
-    except PackageNotFoundError as pe:
-        app_logger.error(f"pe:{pe}.")
-        samgis_lisa_on_cuda_version = "0.0.0"
-
-    msg = "still alive, "
-    msg += f"""version:{samgis_lisa_on_cuda_version}, core version:{core_version},"""
-    msg += f"""lisa-on-cuda version:{lisa_on_cuda_version},"""
-
-    app_logger.info(msg)
-    return JSONResponse(status_code=200, content={"msg": "still alive..."})
-
-
-@app.post("/post_test_string")
-def post_test_string(request_input: StringPromptApiRequestBody) -> JSONResponse:
-    from lisa_on_cuda.utils import app_helpers
-    from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt
-
-    request_body = get_parsed_bbox_points_with_string_prompt(request_input)
-    app_logger.info(f"request_body:{request_body}.")
-    custom_args = app_helpers.parse_args([])
-    request_body["content"] = {**request_body, "precision": str(custom_args.precision)}
-    return JSONResponse(
-        status_code=200,
-        content=request_body
-    )
-
-
-@app.post("/infer_lisa")
-def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
-    from samgis_lisa_on_cuda.prediction_api import lisa
-    from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt, get_source_name
-
-    app_logger.info("starting lisa inference request...")
-
-    try:
-        import time
-
-        time_start_run = time.time()
-        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
-        app_logger.info(f"lisa body_request:{body_request}.")
-        app_logger.info(f"lisa module:{lisa}.")
-        try:
-            source_name = get_source_name(request_input.source_type)
-            app_logger.info(f"source_name = {source_name}.")
-            output = lisa.lisa_predict(
-                bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
-                source=body_request["source"], source_name=source_name
-            )
-            duration_run = time.time() - time_start_run
-            app_logger.info(f"duration_run:{duration_run}.")
-            body = {
-                "duration_run": duration_run,
-                "output": output
-            }
-            return JSONResponse(status_code=200, content={"body": json.dumps(body)})
-        except Exception as inference_exception:
-            import subprocess
-            project_root_folder_content = subprocess.run(
-                f"ls -l {PROJECT_ROOT_FOLDER}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
-            )
-            app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
-            workdir_folder_content = subprocess.run(
-                f"ls -l {WORKDIR}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
-            )
-            app_logger.error(f"workdir folder 'ls -l' command output: {workdir_folder_content.stdout}.")
-            app_logger.error(f"inference error:{inference_exception}.")
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference")
-    except ValidationError as va1:
-        app_logger.error(f"validation error: {str(va1)}.")
-        raise ValidationError("Unprocessable Entity")
-
-
-@app.post("/infer_samgis")
-def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
-    from samgis_lisa_on_cuda.prediction_api import predictors
-    from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt, get_source_name
-
-    app_logger.info("starting plain samgis inference request...")
-
-    try:
-        import time
-
-        time_start_run = time.time()
-        body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
-        app_logger.info(f"body_request:{body_request}.")
-        try:
-            source_name = get_source_name(request_input.source_type)
-            app_logger.info(f"source_name = {source_name}.")
-            output = predictors.samexporter_predict(
-                bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
-                source=body_request["source"], source_name=source_name
-            )
-            duration_run = time.time() - time_start_run
-            app_logger.info(f"duration_run:{duration_run}.")
-            body = {
-                "duration_run": duration_run,
-                "output": output
-            }
-            return JSONResponse(status_code=200, content={"body": json.dumps(body)})
-        except Exception as inference_exception:
-            import subprocess
-            project_root_folder_content = subprocess.run(
-                f"ls -l {PROJECT_ROOT_FOLDER}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
-            )
-            app_logger.error(f"project_root folder 'ls -l' command output: {project_root_folder_content.stdout}.")
-            workdir_folder_content = subprocess.run(
-                f"ls -l {WORKDIR}/", shell=True, universal_newlines=True, stdout=subprocess.PIPE
-            )
-            app_logger.error(f"workdir folder 'ls -l' command output: {workdir_folder_content.stdout}.")
-            app_logger.error(f"inference error:{inference_exception}.")
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference")
-    except ValidationError as va1:
-        app_logger.error(f"validation error: {str(va1)}.")
-        raise ValidationError("Unprocessable Entity")
-
-
-@app.exception_handler(RequestValidationError)
-async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
-    app_logger.error(f"exception errors: {exc.errors()}.")
-    app_logger.error(f"exception body: {exc.body}.")
-    headers = request.headers.items()
-    app_logger.error(f'request header: {dict(headers)}.')
-    params = request.query_params.items()
-    app_logger.error(f'request query params: {dict(params)}.')
-    return JSONResponse(
-        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-        content={"msg": "Error - Unprocessable Entity"}
-    )
-
-
-@app.exception_handler(HTTPException)
-async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
-    app_logger.error(f"exception: {str(exc)}.")
-    headers = request.headers.items()
-    app_logger.error(f'request header: {dict(headers)}.')
-    params = request.query_params.items()
-    app_logger.error(f'request query params: {dict(params)}.')
-    return JSONResponse(
-        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-        content={"msg": "Error - Internal Server Error"}
-    )
-
-
-write_tmp_on_disk = os.getenv("WRITE_TMP_ON_DISK", "")
-app_logger.info(f"write_tmp_on_disk:{write_tmp_on_disk}.")
-if bool(write_tmp_on_disk):
-    try:
-        path_write_tmp_on_disk = pathlib.Path(write_tmp_on_disk)
-        try:
-            pathlib.Path.unlink(path_write_tmp_on_disk, missing_ok=True)
-        except (IsADirectoryError, PermissionError, OSError) as err:
-            app_logger.error(f"{err} while removing old write_tmp_on_disk:{write_tmp_on_disk}.")
-            app_logger.error(f"is file?{path_write_tmp_on_disk.is_file()}.")
-            app_logger.error(f"is symlink?{path_write_tmp_on_disk.is_symlink()}.")
-            app_logger.error(f"is folder?{path_write_tmp_on_disk.is_dir()}.")
-        os.makedirs(write_tmp_on_disk, exist_ok=True)
-        app.mount("/vis_output", StaticFiles(directory=write_tmp_on_disk), name="vis_output")
-    except RuntimeError as rerr:
-        app_logger.error(f"{rerr} while loading the folder write_tmp_on_disk:{write_tmp_on_disk}...")
-        raise rerr
-templates = Jinja2Templates(directory=WORKDIR / "static")
-
-
-@app.get("/vis_output", response_class=HTMLResponse)
-def list_files(request: Request):
-
-    files = os.listdir(write_tmp_on_disk)
-    files_paths = sorted([f"{request.url._url}/{f}" for f in files])
-    print(files_paths)
-    return templates.TemplateResponse(
-        "list_files.html", {"request": request, "files": files_paths}
-    )
-
-
-# important: the index() function and the app.mount MUST be at the end
-# samgis.html
-app.mount("/samgis", StaticFiles(directory=WORKDIR / "static" / "dist", html=True), name="samgis")
-
-
-@app.get("/samgis")
-async def samgis() -> FileResponse:
-    return FileResponse(path=WORKDIR / "static" / "dist" / "samgis.html", media_type="text/html")
-
-
-# lisa.html
-app.mount("/lisa", StaticFiles(directory=WORKDIR / "static" / "dist", html=True), name="lisa")
-
-
-@app.get("/lisa")
-async def lisa() -> FileResponse:
-    return FileResponse(path=WORKDIR / "static" / "dist" / "lisa.html", media_type="text/html")
-
-
-# index.html (lisa.html copy)
-app.mount("/", StaticFiles(directory=WORKDIR / "static" / "dist", html=True), name="index")
-
-
-@app.get("/")
-async def index() -> FileResponse:
-    return FileResponse(path=WORKDIR / "static" / "dist" / "index.html", media_type="text/html")
-
-
-if __name__ == '__main__':
-    try:
-        uvicorn.run(host="0.0.0.0", port=7860, app=app)
-    except Exception as e:
-        app_logger.error("e:", e)
-        raise e
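For reference, the deleted wrapper served `/infer_samgis` on the uvicorn host/port from the `__main__` block above, and the accepted body matches the payload shape used throughout the tests. A minimal client sketch (endpoint from this file, payload copied from the deleted lambda tests; reachability of the server is assumed):

```python
import requests

payload = {
    "bbox": {
        "ne": {"lat": 38.03932961278458, "lng": 15.36808069832851},
        "sw": {"lat": 37.455509218936974, "lng": 14.632807441554068},
    },
    "prompt": [{"type": "point", "data": {"lat": 37.0, "lng": 15.0}, "label": 0}],
    "zoom": 10,
    "source_type": "OpenStreetMap.Mapnik",
}
response = requests.post("http://localhost:7860/infer_samgis", json=payload)
print(response.status_code)  # 200 on success, with a JSON-encoded "body" field
```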
wrappers/lambda_wrapper.py
DELETED
@@ -1,58 +0,0 @@
-"""Lambda entry point"""
-from http import HTTPStatus
-from typing import Dict
-
-from aws_lambda_powertools.utilities.typing import LambdaContext
-from pydantic import ValidationError
-
-from samgis_lisa_on_cuda import app_logger
-from samgis_lisa_on_cuda.io.wrappers_helpers import get_parsed_request_body, get_parsed_bbox_points_with_dictlist_prompt, get_response
-from samgis_lisa_on_cuda.prediction_api.predictors import samexporter_predict
-
-
-def lambda_handler(event: Dict, context: LambdaContext) -> str:
-    """
-    Handle the request for the serverless backend and return the response
-    (success or a type of error based on the exception raised).
-
-    Args:
-        event: request content
-        context: request context
-
-    Returns:
-        json response from get_response() function
-
-    """
-    from time import time
-    app_logger.info(f"start with aws_request_id:{context.aws_request_id}.")
-    start_time = time()
-
-    if "version" in event:
-        app_logger.info(f"event version: {event['version']}.")
-
-    try:
-        app_logger.info("try get_parsed_event...")
-        request_input = get_parsed_request_body(event)
-        app_logger.info("event parsed: ok")
-        body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
-        app_logger.info(f"body_request => {type(body_request)}, {body_request}.")
-
-        try:
-            body_response = samexporter_predict(
-                body_request["bbox"], body_request["prompt"], body_request["zoom"], source=body_request["source"]
-            )
-            app_logger.info(f"output body_response length:{len(body_response)}.")
-            app_logger.debug(f"output body_response:{body_response}.")
-            response = get_response(HTTPStatus.OK.value, start_time, context.aws_request_id, body_response)
-        except Exception as ex2:
-            app_logger.exception(f"exception2:{ex2}.", exc_info=True)
-            response = get_response(HTTPStatus.INTERNAL_SERVER_ERROR.value, start_time, context.aws_request_id, {})
-    except ValidationError as va1:
-        app_logger.exception(f"ValidationError:{va1}.", exc_info=True)
-        response = get_response(HTTPStatus.UNPROCESSABLE_ENTITY.value, start_time, context.aws_request_id, {})
-    except Exception as ex1:
-        app_logger.exception(f"exception1:{ex1}.", exc_info=True)
-        response = get_response(HTTPStatus.BAD_REQUEST.value, start_time, context.aws_request_id, {})
-
-    app_logger.debug(f"response_dumped:{response}...")
-    return response
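The string this handler returns is the JSON envelope asserted verbatim in the deleted `tests/test_lambda_app.py`; a sketch of how a caller unpacks it (`event` and `lambda_context` as constructed in those tests):

```python
import json

raw = lambda_handler(event, lambda_context)      # the handler returns a JSON string
envelope = json.loads(raw)
assert envelope["header"] == {"Content-Type": "application/json"}
assert envelope["isBase64Encoded"] is False
body = json.loads(envelope["body"])              # duration_run, message, request_id (+ payload on 200)
print(envelope["statusCode"], body["message"])
```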