Spaces:
Running
Running
import json | |
import uuid | |
import uvicorn | |
from fastapi import FastAPI, HTTPException, Request, status | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.responses import FileResponse, JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import ValidationError | |
from samgis import PROJECT_ROOT_FOLDER | |
from samgis.utilities.type_hints import ApiRequestBody, StringPromptApiRequestBody | |
from samgis_core.utilities.fastapi_logger import setup_logging | |
# Shared logger used by the middleware, the endpoints and the exception
# handlers in this module; configured through samgis_core with debug=True.
app_logger = setup_logging(debug=True)

# The ASGI application object, served by uvicorn in the __main__ guard below.
app = FastAPI()
async def request_middleware(request, call_next):
    """Tag every request with a fresh UUID and log when it starts and ends.

    The id is bound to the logger context for the duration of the request and
    returned to the client in the ``X-Request-ID`` response header.

    Args:
        request: the incoming request, forwarded unchanged to ``call_next``.
        call_next: awaitable callable that produces the downstream response.

    Returns:
        The downstream response, or a ``{"success": False}`` JSON response with
        status 500 when the downstream handler raised an ``Exception``.
        Exceptions outside ``Exception`` (e.g. cancellation) propagate as-is.
    """
    request_id = str(uuid.uuid4())
    with app_logger.contextualize(request_id=request_id):
        app_logger.info("Request started")
        try:
            response = await call_next(request)
        except Exception as ex:
            app_logger.error(f"Request failed: {ex}")
            response = JSONResponse(content={"success": False}, status_code=500)
        # Deliberately not in a `finally`: this point is only reached when
        # `response` is bound, so a BaseException raised by call_next is not
        # replaced by an UnboundLocalError here.
        response.headers["X-Request-ID"] = request_id
        app_logger.info("Request ended")
        return response
def post_test_dictlist2(request_input: ApiRequestBody) -> JSONResponse:
    """Parse a dict-list prompt request and echo the parsed result back.

    Uses the same parser as ``infer_samgis`` but runs no model; judging by its
    name this is a test/debug endpoint.

    Args:
        request_input: the request body to parse.

    Returns:
        A 200 ``JSONResponse`` whose content is the parsed request.
    """
    from samgis.io.wrappers_helpers import (
        get_parsed_bbox_points_with_dictlist_prompt as parse_request,
    )

    parsed = parse_request(request_input)
    app_logger.info(f"request_body:{parsed}.")
    return JSONResponse(content=parsed, status_code=200)
async def health() -> JSONResponse:
    """Health check: log the samgis / samgis_core versions and confirm liveness.

    Returns:
        A 200 ``JSONResponse`` with the fixed body ``{"msg": "still alive..."}``.
    """
    from samgis.__version__ import __version__ as samgis_version
    from samgis_core.__version__ import __version__ as samgis_core_version

    app_logger.info(f"still alive, version:{samgis_version}, core version:{samgis_core_version}.")
    payload = {"msg": "still alive..."}
    return JSONResponse(content=payload, status_code=200)
def post_test_string(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """Parse a string-prompt request and echo it back with the LISA precision.

    The precision comes from ``lisa_on_cuda``'s argument parser called with no
    CLI arguments, i.e. its default configuration.

    Args:
        request_input: the request body to parse.

    Returns:
        A 200 ``JSONResponse`` with the parsed request plus a ``"content"`` key
        (see the note below).
    """
    from lisa_on_cuda.utils import app_helpers
    from samgis.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt

    parsed = get_parsed_bbox_points_with_string_prompt(request_input)
    app_logger.info(f"request_body:{parsed}.")
    precision = str(app_helpers.parse_args([]).precision)
    # NOTE(review): this stores a *copy* of the parsed body (plus "precision")
    # under a "content" key of the body itself, so most fields appear twice in
    # the response. It may have been meant as `JSONResponse(content=...)` —
    # confirm with this endpoint's consumers before changing the shape.
    parsed["content"] = dict(parsed, precision=precision)
    return JSONResponse(content=parsed, status_code=200)
def infer_lisa(request_input: StringPromptApiRequestBody) -> JSONResponse:
    """Run ``lisa.lisa_predict`` on a string-prompt request.

    Args:
        request_input: request body; after parsing it must provide the
            ``bbox``, ``prompt``, ``zoom`` and ``source`` keys.

    Returns:
        A 200 ``JSONResponse`` whose ``"body"`` field is a JSON-encoded *string*
        of ``{"duration_run": <seconds>, "output": <lisa_predict result>}``.

    Raises:
        RequestValidationError: parsing the request raised a pydantic
            ``ValidationError`` (FastAPI answers it with HTTP 422).
        HTTPException: status 500 when inference or response encoding fails.
    """
    import subprocess
    import time

    from samgis.prediction_api import lisa
    from samgis.io.wrappers_helpers import get_parsed_bbox_points_with_string_prompt

    app_logger.info("starting lisa inference request...")
    time_start_run = time.perf_counter()
    try:
        body_request = get_parsed_bbox_points_with_string_prompt(request_input)
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        # pydantic's ValidationError cannot be built from a plain message
        # (`ValidationError("...")` raises TypeError), so hand the errors to
        # FastAPI's request-validation exception instead: it yields a 422.
        raise RequestValidationError(errors=va1.errors()) from va1

    app_logger.info(f"lisa body_request:{body_request}.")
    app_logger.info(f"lisa module:{lisa}.")
    try:
        output = lisa.lisa_predict(
            bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
            source=body_request["source"]
        )
        duration_run = time.perf_counter() - time_start_run
        app_logger.info(f"duration_run:{duration_run}.")
        body = {
            "duration_run": duration_run,
            "output": output
        }
        return JSONResponse(status_code=200, content={"body": json.dumps(body)})
    except Exception as inference_exception:
        # Diagnostic aid: log the project folder contents. argv form (no shell)
        # copes with spaces in the path; a failure here must not hide the
        # original inference error.
        try:
            listing = subprocess.run(
                ["ls", "-l", f"{PROJECT_ROOT_FOLDER}/"],
                stdout=subprocess.PIPE, universal_newlines=True, check=False,
            ).stdout
        except OSError as listing_error:
            listing = f"unavailable ({listing_error})"
        app_logger.error(f"'ls -l' command output: {listing}.")
        app_logger.exception(f"inference error:{inference_exception}.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
        ) from inference_exception
def infer_samgis(request_input: ApiRequestBody) -> JSONResponse:
    """Run ``predictors.samexporter_predict`` on a dict-list prompt request.

    Args:
        request_input: request body; after parsing it must provide the
            ``bbox``, ``prompt``, ``zoom`` and ``source`` keys.

    Returns:
        A 200 ``JSONResponse`` whose ``"body"`` field is a JSON-encoded *string*
        of ``{"duration_run": <seconds>, "output": <samexporter_predict result>}``.

    Raises:
        RequestValidationError: parsing the request raised a pydantic
            ``ValidationError`` (FastAPI answers it with HTTP 422).
        HTTPException: status 500 when inference or response encoding fails.
    """
    import subprocess
    import time

    from samgis.prediction_api import predictors
    from samgis.io.wrappers_helpers import get_parsed_bbox_points_with_dictlist_prompt

    app_logger.info("starting samgis inference request...")
    time_start_run = time.perf_counter()
    try:
        body_request = get_parsed_bbox_points_with_dictlist_prompt(request_input)
    except ValidationError as va1:
        app_logger.error(f"validation error: {str(va1)}.")
        # pydantic's ValidationError cannot be built from a plain message
        # (`ValidationError("...")` raises TypeError), so hand the errors to
        # FastAPI's request-validation exception instead: it yields a 422.
        raise RequestValidationError(errors=va1.errors()) from va1

    app_logger.info(f"body_request:{body_request}.")
    try:
        output = predictors.samexporter_predict(
            bbox=body_request["bbox"], prompt=body_request["prompt"], zoom=body_request["zoom"],
            source=body_request["source"]
        )
        duration_run = time.perf_counter() - time_start_run
        app_logger.info(f"duration_run:{duration_run}.")
        body = {
            "duration_run": duration_run,
            "output": output
        }
        return JSONResponse(status_code=200, content={"body": json.dumps(body)})
    except Exception as inference_exception:
        # Diagnostic aid: log the project folder contents (same folder as in
        # infer_lisa), with a message that names what was actually listed.
        # A failure here must not hide the original inference error.
        try:
            listing = subprocess.run(
                ["ls", "-l", f"{PROJECT_ROOT_FOLDER}/"],
                stdout=subprocess.PIPE, universal_newlines=True, check=False,
            ).stdout
        except OSError as listing_error:
            listing = f"unavailable ({listing_error})"
        app_logger.error(f"'ls -l {PROJECT_ROOT_FOLDER}/' command output: {listing}.")
        app_logger.exception(f"inference error:{inference_exception}.")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error on inference"
        ) from inference_exception
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Log the details of a request that failed validation and answer 422.

    Logs the validation errors, the received body, the request headers and the
    query parameters, then returns a generic message without those details.

    Args:
        request: the rejected request.
        exc: the validation exception raised for it.

    Returns:
        A ``JSONResponse`` with status 422 and ``{"msg": "Error - Unprocessable Entity"}``.
    """
    app_logger.error(f"exception errors: {exc.errors()}.")
    app_logger.error(f"exception body: {exc.body}.")
    # NOTE(review): headers are logged verbatim, which may include credentials
    # such as Authorization or Cookie values — consider redacting.
    request_details = (
        ("request header", request.headers.items()),
        ("request query params", request.query_params.items()),
    )
    for label, pairs in request_details:
        app_logger.error(f"{label}: {dict(pairs)}.")
    return JSONResponse(
        content={"msg": "Error - Unprocessable Entity"},
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    )
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Log an ``HTTPException`` with its request context and answer with its status.

    The response body is always the generic ``{"msg": "Error - <reason phrase>"}``;
    the exception detail is logged but never sent to the client. The exception's
    own status code and headers are preserved, so e.g. a 401/404 is no longer
    reported to the client as a 500. For a 500 the response is unchanged:
    ``{"msg": "Error - Internal Server Error"}``.

    Args:
        request: the request whose handling raised ``exc``.
        exc: the raised HTTP exception.

    Returns:
        A ``JSONResponse`` with ``exc.status_code`` and the generic message.
    """
    from http import HTTPStatus

    app_logger.error(f"exception: {str(exc)}.")
    headers = request.headers.items()
    app_logger.error(f'request header: {dict(headers)}.')
    params = request.query_params.items()
    app_logger.error(f'request query params: {dict(params)}.')
    try:
        reason = HTTPStatus(exc.status_code).phrase
    except ValueError:
        # Non-standard status code: there is no reason phrase to show.
        reason = str(exc.status_code)
    return JSONResponse(
        status_code=exc.status_code,
        content={"msg": f"Error - {reason}"},
        headers=getattr(exc, "headers", None),
    )
# Important: these static mounts and the lisa()/index() page handlers MUST stay
# at the very end of the module, in this order (a mount on "/" matches every
# path, so anything registered after it would presumably be shadowed).
app.mount(
    "/lisa",
    StaticFiles(html=True, directory=PROJECT_ROOT_FOLDER / "static" / "dist"),
    name="lisa",
)


async def lisa() -> FileResponse:
    """Serve ``static/dist/lisa.html`` as an HTML page."""
    page = PROJECT_ROOT_FOLDER / "static" / "dist" / "lisa.html"
    return FileResponse(media_type="text/html", path=page)


app.mount(
    "/",
    StaticFiles(html=True, directory=PROJECT_ROOT_FOLDER / "static" / "dist"),
    name="static",
)


async def index() -> FileResponse:
    """Serve ``static/dist/index.html`` as an HTML page."""
    page = PROJECT_ROOT_FOLDER / "static" / "dist" / "index.html"
    return FileResponse(media_type="text/html", path=page)
if __name__ == '__main__':
    # Development entry point: serve the app on all interfaces, port 7860.
    try:
        uvicorn.run(host="0.0.0.0", port=7860, app=app)
    except Exception as e:
        # Embed the exception in the message itself: passed as a separate
        # argument with no placeholder, it would be dropped from the log.
        app_logger.error(f"e: {e}")
        raise