alessandro trinca tornidor
commited on
Commit
•
c06a3f6
1
Parent(s):
09433ba
[test] add fastapi app test cases
Browse files- .coveragerc +1 -1
- .idea/other.xml +6 -0
- pyproject.toml +1 -0
- samgis/__init__.py +1 -27
- samgis/io/wrappers_helpers.py +16 -10
- samgis/utilities/fastapi_logger.py +24 -0
- tests/io/test_lambda_helpers.py +5 -5
- tests/test_fastapi_app.py +172 -0
- tests/{test_app.py → test_lambda_app.py} +9 -4
- wrappers/fastapi_wrapper.py +16 -12
.coveragerc
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
[run]
|
2 |
source = samgis
|
3 |
-
omit = ./venv
|
4 |
|
5 |
[report]
|
6 |
omit = ./venv/*,*tests*,*apps.py,*manage.py,*__init__.py,*migrations*,*asgi*,*wsgi*,*admin.py,*urls.py,./tests/*
|
|
|
1 |
[run]
|
2 |
source = samgis
|
3 |
+
omit = ./venv/*,__version__.py,*tests*,*apps.py,*manage.py,*__init__.py,*migrations*,*asgi*,*wsgi*,*admin.py,*urls.py,./tests/*
|
4 |
|
5 |
[report]
|
6 |
omit = ./venv/*,*tests*,*apps.py,*manage.py,*__init__.py,*migrations*,*asgi*,*wsgi*,*admin.py,*urls.py,./tests/*
|
.idea/other.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="PySciProjectComponent">
|
4 |
+
<option name="PY_INTERACTIVE_PLOTS_SUGGESTED" value="true" />
|
5 |
+
</component>
|
6 |
+
</project>
|
pyproject.toml
CHANGED
@@ -36,6 +36,7 @@ optional = true
|
|
36 |
pytest = "^7.4.3"
|
37 |
pytest-cov = "^4.1.0"
|
38 |
python-dotenv = "^1.0.0"
|
|
|
39 |
|
40 |
[tool.poetry.group.docs]
|
41 |
optional = true
|
|
|
36 |
pytest = "^7.4.3"
|
37 |
pytest-cov = "^4.1.0"
|
38 |
python-dotenv = "^1.0.0"
|
39 |
+
httpx = "^0.26.0"
|
40 |
|
41 |
[tool.poetry.group.docs]
|
42 |
optional = true
|
samgis/__init__.py
CHANGED
@@ -5,7 +5,6 @@ from pathlib import Path
|
|
5 |
|
6 |
from samgis.utilities.constants import SERVICE_NAME
|
7 |
|
8 |
-
|
9 |
PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
|
10 |
MODEL_FOLDER = Path(PROJECT_ROOT_FOLDER / "machine_learning_models")
|
11 |
try:
|
@@ -13,29 +12,4 @@ try:
|
|
13 |
|
14 |
app_logger = Logger(service=SERVICE_NAME)
|
15 |
except ModuleNotFoundError:
|
16 |
-
import
|
17 |
-
|
18 |
-
def setup_logging(debug: bool = False, formatter: str = "{time} - {level} - ({extra[request_id]}) {message} "
|
19 |
-
) -> loguru.logger:
|
20 |
-
"""
|
21 |
-
Create a logging instance with log string formatter.
|
22 |
-
|
23 |
-
Args:
|
24 |
-
debug: logging debug argument
|
25 |
-
formatter: log string formatter
|
26 |
-
|
27 |
-
Returns:
|
28 |
-
Logger
|
29 |
-
|
30 |
-
"""
|
31 |
-
import sys
|
32 |
-
|
33 |
-
logger = loguru.logger
|
34 |
-
logger.remove()
|
35 |
-
level_logger = "DEBUG" if debug else "INFO"
|
36 |
-
logger.add(sys.stdout, format=formatter, level=level_logger)
|
37 |
-
logger.info(f"type_logger:{type(logger)}, logger:{logger}.")
|
38 |
-
return logger
|
39 |
-
|
40 |
-
|
41 |
-
app_logger = setup_logging(debug=True)
|
|
|
5 |
|
6 |
from samgis.utilities.constants import SERVICE_NAME
|
7 |
|
|
|
8 |
PROJECT_ROOT_FOLDER = Path(globals().get("__file__", "./_")).absolute().parent.parent
|
9 |
MODEL_FOLDER = Path(PROJECT_ROOT_FOLDER / "machine_learning_models")
|
10 |
try:
|
|
|
12 |
|
13 |
app_logger = Logger(service=SERVICE_NAME)
|
14 |
except ModuleNotFoundError:
|
15 |
+
from wrappers.fastapi_wrapper import app_logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
samgis/io/wrappers_helpers.py
CHANGED
@@ -165,16 +165,22 @@ nextzen_terrain_rgb = TileProvider(
|
|
165 |
|
166 |
|
167 |
def get_url_tile(source_type: str):
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
|
180 |
def check_source_type_is_terrain(source: str | TileProvider):
|
|
|
165 |
|
166 |
|
167 |
def get_url_tile(source_type: str):
|
168 |
+
try:
|
169 |
+
match source_type.lower():
|
170 |
+
case TmsDefaultProvidersNames.DEFAULT_TILES_NAME_SHORT:
|
171 |
+
return providers.query_name(TmsDefaultProvidersNames.DEFAULT_TILES_NAME)
|
172 |
+
case TmsTerrainProvidersNames.MAPBOX_TERRAIN_TILES_NAME:
|
173 |
+
return mapbox_terrain_rgb
|
174 |
+
case TmsTerrainProvidersNames.NEXTZEN_TERRAIN_TILES_NAME:
|
175 |
+
app_logger.info("nextzen_terrain_rgb:", nextzen_terrain_rgb)
|
176 |
+
return nextzen_terrain_rgb
|
177 |
+
|
178 |
+
return providers.query_name(source_type)
|
179 |
+
except ValueError as ve:
|
180 |
+
from pydantic_core import ValidationError
|
181 |
+
|
182 |
+
app_logger.error("ve:", str(ve))
|
183 |
+
raise ValidationError(ve)
|
184 |
|
185 |
|
186 |
def check_source_type_is_terrain(source: str | TileProvider):
|
samgis/utilities/fastapi_logger.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import loguru
|
2 |
+
|
3 |
+
|
4 |
+
def setup_logging(debug: bool = False, formatter: str = "{time} - {level} - ({extra[request_id]}) {message} "
|
5 |
+
) -> loguru.logger:
|
6 |
+
"""
|
7 |
+
Create a logging instance with log string formatter.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
debug: logging debug argument
|
11 |
+
formatter: log string formatter
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
Logger
|
15 |
+
|
16 |
+
"""
|
17 |
+
import sys
|
18 |
+
|
19 |
+
logger = loguru.logger
|
20 |
+
logger.remove()
|
21 |
+
level_logger = "DEBUG" if debug else "INFO"
|
22 |
+
logger.add(sys.stdout, format=formatter, level=level_logger)
|
23 |
+
logger.info(f"type_logger:{type(logger)}, logger:{logger}.")
|
24 |
+
return logger
|
tests/io/test_lambda_helpers.py
CHANGED
@@ -3,8 +3,8 @@ import time
|
|
3 |
from http import HTTPStatus
|
4 |
from unittest.mock import patch
|
5 |
|
6 |
-
from samgis.io import
|
7 |
-
from samgis.io.
|
8 |
from samgis.utilities.type_hints import ApiRequestBody
|
9 |
from samgis.utilities import utilities
|
10 |
from tests import TEST_EVENTS_FOLDER
|
@@ -89,10 +89,10 @@ def test_get_parsed_request_body():
|
|
89 |
assert output == ApiRequestBody.model_validate(expected_output_dict)
|
90 |
|
91 |
|
92 |
-
@patch.object(
|
93 |
def test_get_url_tile(providers_mocked):
|
94 |
import xyzservices
|
95 |
-
from samgis.io.
|
96 |
|
97 |
from tests import LOCAL_URL_TILE
|
98 |
|
@@ -112,7 +112,7 @@ def test_get_url_tile(providers_mocked):
|
|
112 |
|
113 |
|
114 |
def test_get_url_tile_real():
|
115 |
-
from samgis.io.
|
116 |
|
117 |
assert get_url_tile("OpenStreetMap") == {
|
118 |
'url': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19,
|
|
|
3 |
from http import HTTPStatus
|
4 |
from unittest.mock import patch
|
5 |
|
6 |
+
from samgis.io import wrappers_helpers
|
7 |
+
from samgis.io.wrappers_helpers import get_parsed_bbox_points, get_parsed_request_body, get_response
|
8 |
from samgis.utilities.type_hints import ApiRequestBody
|
9 |
from samgis.utilities import utilities
|
10 |
from tests import TEST_EVENTS_FOLDER
|
|
|
89 |
assert output == ApiRequestBody.model_validate(expected_output_dict)
|
90 |
|
91 |
|
92 |
+
@patch.object(wrappers_helpers, "providers")
|
93 |
def test_get_url_tile(providers_mocked):
|
94 |
import xyzservices
|
95 |
+
from samgis.io.wrappers_helpers import get_url_tile
|
96 |
|
97 |
from tests import LOCAL_URL_TILE
|
98 |
|
|
|
112 |
|
113 |
|
114 |
def test_get_url_tile_real():
|
115 |
+
from samgis.io.wrappers_helpers import get_url_tile
|
116 |
|
117 |
assert get_url_tile("OpenStreetMap") == {
|
118 |
'url': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19,
|
tests/test_fastapi_app.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
import unittest
|
4 |
+
from unittest.mock import patch
|
5 |
+
|
6 |
+
from fastapi.testclient import TestClient
|
7 |
+
|
8 |
+
from samgis import PROJECT_ROOT_FOLDER
|
9 |
+
from tests import TEST_EVENTS_FOLDER
|
10 |
+
from wrappers import fastapi_wrapper
|
11 |
+
from wrappers.fastapi_wrapper import app
|
12 |
+
|
13 |
+
client = TestClient(app)
|
14 |
+
source = {
|
15 |
+
'url': 'https://tile.openstreetmap.org/{z}/{x}/{y}.png', 'max_zoom': 19,
|
16 |
+
'html_attribution': '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a> contributors',
|
17 |
+
'attribution': '(C) OpenStreetMap contributors', 'name': 'OpenStreetMap.Mapnik'
|
18 |
+
}
|
19 |
+
event = {
|
20 |
+
"bbox": {
|
21 |
+
"ne": {"lat": 39.036252959636606, "lng": 15.040283203125002},
|
22 |
+
"sw": {"lat": 38.302869955150044, "lng": 13.634033203125002}
|
23 |
+
},
|
24 |
+
"prompt": [{"type": "point", "data": {"lat": 38.48542007717153, "lng": 14.921846904165468}, "label": 0}],
|
25 |
+
"zoom": 10, "source_type": "OpenStreetMap"
|
26 |
+
}
|
27 |
+
response_bodies_post_test = {
|
28 |
+
"single_point": {
|
29 |
+
'bbox': [[39.036252959636606, 15.040283203125002], [38.302869955150044, 13.634033203125002]],
|
30 |
+
'prompt': [{'type': 'point', 'label': 0, 'data': [937, 514]}], 'zoom': 10,
|
31 |
+
'source': source
|
32 |
+
},
|
33 |
+
"multi_prompt": {
|
34 |
+
'bbox': [[39.011714588834074, 15.093841552734377], [38.278078995562105, 13.687591552734377]],
|
35 |
+
'prompt': [
|
36 |
+
{'type': 'point', 'label': 1, 'data': [839, 421]},
|
37 |
+
{'type': 'point', 'label': 1, 'data': [906, 489]},
|
38 |
+
{'type': 'point', 'label': 1, 'data': [936, 580]}
|
39 |
+
], 'zoom': 10,
|
40 |
+
'source': source
|
41 |
+
},
|
42 |
+
"single_rectangle": {
|
43 |
+
'bbox': [[39.011714588834074, 15.093841552734377], [38.278078995562105, 13.687591552734377]],
|
44 |
+
'prompt': [{'type': 'rectangle', 'data': [875, 445, 951, 538]}], 'zoom': 10,
|
45 |
+
'source': source
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
class TestFastapiApp(unittest.TestCase):
|
51 |
+
def test_fastapi_handler_health_200(self):
|
52 |
+
response = client.get("/health")
|
53 |
+
assert response.status_code == 200
|
54 |
+
body = response.json()
|
55 |
+
assert body == {"msg": "still alive..."}
|
56 |
+
|
57 |
+
def test_fastapi_handler_post_test_200(self):
|
58 |
+
fn_name = "lambda_handler"
|
59 |
+
for json_filename in [
|
60 |
+
"single_point",
|
61 |
+
"multi_prompt",
|
62 |
+
"single_rectangle"
|
63 |
+
]:
|
64 |
+
with open(TEST_EVENTS_FOLDER / f"{fn_name}_{json_filename}.json") as tst_json:
|
65 |
+
inputs_outputs = json.load(tst_json)
|
66 |
+
input_body = json.loads(inputs_outputs["input"]["body"])
|
67 |
+
response = client.post("/post_test", json=input_body)
|
68 |
+
assert response.status_code == 200
|
69 |
+
response_body = response.json()
|
70 |
+
assert response_body == response_bodies_post_test[json_filename]
|
71 |
+
|
72 |
+
def test_fastapi_handler_post_test_422(self):
|
73 |
+
response = client.post("/post_test", json={})
|
74 |
+
assert response.status_code == 422
|
75 |
+
body = response.json()
|
76 |
+
assert body == {'msg': 'Error - Unprocessable Entity'}
|
77 |
+
|
78 |
+
def test_index(self):
|
79 |
+
import subprocess
|
80 |
+
|
81 |
+
subprocess.run(["pnpm", "build"], cwd=PROJECT_ROOT_FOLDER / "static")
|
82 |
+
subprocess.run(["pnpm", "tailwindcss", "-i", "./src/input.css", "-o", "./dist/output.css"],
|
83 |
+
cwd=PROJECT_ROOT_FOLDER / "static")
|
84 |
+
response = client.get("/")
|
85 |
+
assert response.status_code == 200
|
86 |
+
html_body = response.read().decode("utf-8")
|
87 |
+
assert "html" in html_body
|
88 |
+
assert "head" in html_body
|
89 |
+
assert "body" in html_body
|
90 |
+
|
91 |
+
def test_404(self):
|
92 |
+
response = client.get("/404")
|
93 |
+
assert response.status_code == 404
|
94 |
+
|
95 |
+
def test_infer_samgis_422(self):
|
96 |
+
response = client.post("/infer_samgis", json={})
|
97 |
+
print("response.status_code:", response.status_code)
|
98 |
+
assert response.status_code == 422
|
99 |
+
body_loaded = response.json()
|
100 |
+
print("response.body_loaded:", body_loaded)
|
101 |
+
assert body_loaded == {"msg": "Error - Unprocessable Entity"}
|
102 |
+
|
103 |
+
def test_infer_samgis_middleware_500(self):
|
104 |
+
from copy import deepcopy
|
105 |
+
local_event = deepcopy(event)
|
106 |
+
|
107 |
+
local_event["source_type"] = "source_fake"
|
108 |
+
response = client.post("/infer_samgis", json=local_event)
|
109 |
+
print("response.status_code:", response.status_code)
|
110 |
+
assert response.status_code == 500
|
111 |
+
body_loaded = response.json()
|
112 |
+
print("response.body_loaded:", body_loaded)
|
113 |
+
assert body_loaded == {'success': False}
|
114 |
+
|
115 |
+
@patch.object(time, "time")
|
116 |
+
@patch.object(fastapi_wrapper, "samexporter_predict")
|
117 |
+
def test_infer_samgis_500(self, samexporter_predict_mocked, time_mocked):
|
118 |
+
time_mocked.return_value = 0
|
119 |
+
samexporter_predict_mocked.side_effect = ValueError("I raise a value error!")
|
120 |
+
|
121 |
+
response = client.post("/infer_samgis", json=event)
|
122 |
+
print("response.status_code:", response.status_code)
|
123 |
+
assert response.status_code == 500
|
124 |
+
body = response.json()
|
125 |
+
print("response.body:", body)
|
126 |
+
assert body == {'msg': 'Error - Internal Server Error'}
|
127 |
+
|
128 |
+
@patch.object(time, "time")
|
129 |
+
def test_infer_samgis_real_200(self, time_mocked):
|
130 |
+
import shapely
|
131 |
+
|
132 |
+
time_mocked.return_value = 0
|
133 |
+
|
134 |
+
response = client.post("/infer_samgis", json=event)
|
135 |
+
print("response.status_code:", response.status_code)
|
136 |
+
assert response.status_code == 200
|
137 |
+
body_string = response.json()["body"]
|
138 |
+
body_loaded = json.loads(body_string)
|
139 |
+
print("response.body_loaded:", body_loaded)
|
140 |
+
assert "duration_run" in body_loaded
|
141 |
+
output = body_loaded["output"]
|
142 |
+
assert 'n_predictions' in output
|
143 |
+
assert "n_shapes_geojson" in output
|
144 |
+
geojson = output["geojson"]
|
145 |
+
output_geojson = shapely.from_geojson(geojson)
|
146 |
+
print("output_geojson::", type(output_geojson))
|
147 |
+
assert isinstance(output_geojson, shapely.GeometryCollection)
|
148 |
+
assert len(output_geojson.geoms) == 3
|
149 |
+
|
150 |
+
@patch.object(time, "time")
|
151 |
+
@patch.object(fastapi_wrapper, "samexporter_predict")
|
152 |
+
def test_infer_samgis_mocked_200(self, samexporter_predict_mocked, time_mocked):
|
153 |
+
self.maxDiff = None
|
154 |
+
|
155 |
+
time_mocked.return_value = 0
|
156 |
+
samexporter_output = {
|
157 |
+
"n_predictions": 1,
|
158 |
+
"geojson": "{\"type\": \"FeatureCollection\", \"features\": [{\"id\": \"0\", \"type\": \"Feature\", " +
|
159 |
+
"\"properties\": {\"raster_val\": 255.0}, \"geometry\": {\"type\": \"Polygon\", " +
|
160 |
+
"\"coordinates\": [[[148.359375, -40.4469470596005], [148.447265625, -40.4469470596005], " +
|
161 |
+
"[148.447265625, -40.51379915504414], [148.359375, -40.4469470596005]]]}}]}",
|
162 |
+
"n_shapes_geojson": 2
|
163 |
+
}
|
164 |
+
samexporter_predict_mocked.return_value = samexporter_output
|
165 |
+
|
166 |
+
response = client.post("/infer_samgis", json=event)
|
167 |
+
print("response.status_code:", response.status_code)
|
168 |
+
assert response.status_code == 200
|
169 |
+
response_json = response.json()
|
170 |
+
body_loaded = json.loads(response_json["body"])
|
171 |
+
print("response.body_loaded:", body_loaded)
|
172 |
+
self.assertDictEqual(body_loaded, {'duration_run': 0, 'output': samexporter_output})
|
tests/{test_app.py → test_lambda_app.py}
RENAMED
@@ -3,14 +3,19 @@ import time
|
|
3 |
import unittest
|
4 |
from unittest.mock import patch
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
|
|
9 |
from wrappers import lambda_wrapper
|
10 |
from tests.local_tiles_http_server import LocalTilesHttpServer
|
11 |
|
12 |
|
13 |
-
class
|
14 |
@patch.object(time, "time")
|
15 |
@patch.object(lambda_wrapper, "samexporter_predict")
|
16 |
@patch.object(lambda_wrapper, "get_parsed_bbox_points")
|
@@ -141,7 +146,7 @@ class TestAppFailures(unittest.TestCase):
|
|
141 |
print(f"types: response_200:{type(response_200)}, expected:{type(expected_response_200)}.")
|
142 |
assert response_200 == expected_response_200
|
143 |
|
144 |
-
@patch.object(
|
145 |
def test_lambda_handler_200_real_single_multi_point(self, get_url_tile_mocked):
|
146 |
import xyzservices
|
147 |
import shapely
|
|
|
3 |
import unittest
|
4 |
from unittest.mock import patch
|
5 |
|
6 |
+
try:
|
7 |
+
from awslambdaric.lambda_context import LambdaContext
|
8 |
+
except ImportError:
|
9 |
+
import pip
|
10 |
+
pip.main(['install', 'awslambdaric'])
|
11 |
|
12 |
+
|
13 |
+
from samgis.io import wrappers_helpers
|
14 |
from wrappers import lambda_wrapper
|
15 |
from tests.local_tiles_http_server import LocalTilesHttpServer
|
16 |
|
17 |
|
18 |
+
class TestLambdaApp(unittest.TestCase):
|
19 |
@patch.object(time, "time")
|
20 |
@patch.object(lambda_wrapper, "samexporter_predict")
|
21 |
@patch.object(lambda_wrapper, "get_parsed_bbox_points")
|
|
|
146 |
print(f"types: response_200:{type(response_200)}, expected:{type(expected_response_200)}.")
|
147 |
assert response_200 == expected_response_200
|
148 |
|
149 |
+
@patch.object(wrappers_helpers, "get_url_tile")
|
150 |
def test_lambda_handler_200_real_single_multi_point(self, get_url_tile_mocked):
|
151 |
import xyzservices
|
152 |
import shapely
|
wrappers/fastapi_wrapper.py
CHANGED
@@ -5,11 +5,16 @@ from fastapi import FastAPI, HTTPException, Request, status
|
|
5 |
from fastapi.exceptions import RequestValidationError
|
6 |
from fastapi.responses import FileResponse, JSONResponse
|
7 |
from fastapi.staticfiles import StaticFiles
|
|
|
8 |
|
9 |
-
from samgis import
|
10 |
from samgis.io.wrappers_helpers import get_parsed_bbox_points
|
11 |
from samgis.utilities.type_hints import ApiRequestBody
|
|
|
|
|
12 |
|
|
|
|
|
13 |
app = FastAPI()
|
14 |
|
15 |
|
@@ -49,10 +54,7 @@ async def health() -> JSONResponse:
|
|
49 |
|
50 |
|
51 |
@app.post("/infer_samgis")
|
52 |
-
def
|
53 |
-
import subprocess
|
54 |
-
|
55 |
-
from samgis.prediction_api.predictors import samexporter_predict
|
56 |
app_logger.info("starting inference request...")
|
57 |
|
58 |
try:
|
@@ -74,15 +76,17 @@ def samgis(request_input: ApiRequestBody):
|
|
74 |
}
|
75 |
return JSONResponse(status_code=200, content={"body": json.dumps(body)})
|
76 |
except Exception as inference_exception:
|
|
|
77 |
home_content = subprocess.run(
|
78 |
"ls -l /var/task", shell=True, universal_newlines=True, stdout=subprocess.PIPE
|
79 |
)
|
80 |
app_logger.error(f"/home/user ls -l: {home_content.stdout}.")
|
81 |
app_logger.error(f"inference error:{inference_exception}.")
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
86 |
|
87 |
|
88 |
@app.exception_handler(RequestValidationError)
|
@@ -90,7 +94,7 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
|
|
90 |
app_logger.error(f"exception errors: {exc.errors()}.")
|
91 |
app_logger.error(f"exception body: {exc.body}.")
|
92 |
headers = request.headers.items()
|
93 |
-
app_logger.error(f'request header: {dict(headers)}.'
|
94 |
params = request.query_params.items()
|
95 |
app_logger.error(f'request query params: {dict(params)}.')
|
96 |
return JSONResponse(
|
@@ -103,7 +107,7 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
|
|
103 |
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
104 |
app_logger.error(f"exception: {str(exc)}.")
|
105 |
headers = request.headers.items()
|
106 |
-
app_logger.error(f'request header: {dict(headers)}.'
|
107 |
params = request.query_params.items()
|
108 |
app_logger.error(f'request query params: {dict(params)}.')
|
109 |
return JSONResponse(
|
@@ -113,7 +117,7 @@ async def http_exception_handler(request: Request, exc: HTTPException) -> JSONRe
|
|
113 |
|
114 |
|
115 |
# important: the index() function and the app.mount MUST be at the end
|
116 |
-
app.mount("/", StaticFiles(directory="static/dist", html=True), name="static")
|
117 |
|
118 |
|
119 |
@app.get("/")
|
|
|
5 |
from fastapi.exceptions import RequestValidationError
|
6 |
from fastapi.responses import FileResponse, JSONResponse
|
7 |
from fastapi.staticfiles import StaticFiles
|
8 |
+
from pydantic import ValidationError
|
9 |
|
10 |
+
from samgis import PROJECT_ROOT_FOLDER
|
11 |
from samgis.io.wrappers_helpers import get_parsed_bbox_points
|
12 |
from samgis.utilities.type_hints import ApiRequestBody
|
13 |
+
from samgis.utilities.fastapi_logger import setup_logging
|
14 |
+
from samgis.prediction_api.predictors import samexporter_predict
|
15 |
|
16 |
+
|
17 |
+
app_logger = setup_logging(debug=True)
|
18 |
app = FastAPI()
|
19 |
|
20 |
|
|
|
54 |
|
55 |
|
56 |
@app.post("/infer_samgis")
|
57 |
+
def infer_samgis(request_input: ApiRequestBody):
|
|
|
|
|
|
|
58 |
app_logger.info("starting inference request...")
|
59 |
|
60 |
try:
|
|
|
76 |
}
|
77 |
return JSONResponse(status_code=200, content={"body": json.dumps(body)})
|
78 |
except Exception as inference_exception:
|
79 |
+
import subprocess
|
80 |
home_content = subprocess.run(
|
81 |
"ls -l /var/task", shell=True, universal_newlines=True, stdout=subprocess.PIPE
|
82 |
)
|
83 |
app_logger.error(f"/home/user ls -l: {home_content.stdout}.")
|
84 |
app_logger.error(f"inference error:{inference_exception}.")
|
85 |
+
raise HTTPException(
|
86 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference")
|
87 |
+
except ValidationError as va1:
|
88 |
+
app_logger.error(f"validation error: {str(va1)}.")
|
89 |
+
raise ValidationError("Unprocessable Entity")
|
90 |
|
91 |
|
92 |
@app.exception_handler(RequestValidationError)
|
|
|
94 |
app_logger.error(f"exception errors: {exc.errors()}.")
|
95 |
app_logger.error(f"exception body: {exc.body}.")
|
96 |
headers = request.headers.items()
|
97 |
+
app_logger.error(f'request header: {dict(headers)}.')
|
98 |
params = request.query_params.items()
|
99 |
app_logger.error(f'request query params: {dict(params)}.')
|
100 |
return JSONResponse(
|
|
|
107 |
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
108 |
app_logger.error(f"exception: {str(exc)}.")
|
109 |
headers = request.headers.items()
|
110 |
+
app_logger.error(f'request header: {dict(headers)}.')
|
111 |
params = request.query_params.items()
|
112 |
app_logger.error(f'request query params: {dict(params)}.')
|
113 |
return JSONResponse(
|
|
|
117 |
|
118 |
|
119 |
# important: the index() function and the app.mount MUST be at the end
|
120 |
+
app.mount("/", StaticFiles(directory=PROJECT_ROOT_FOLDER / "static" / "dist", html=True), name="static")
|
121 |
|
122 |
|
123 |
@app.get("/")
|