Mirror / src /app /api_backend.py
Spico's picture
update tb and logger
41eec42
import traceback
from typing import Any, Dict, List
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.logging import logger
from src.task import SchemaGuidedInstructBertTask
# Runs first, before the app is built. Presumably seeds the RNGs and points the
# project logger at ``debug.log`` -- confirm against rex.utils.initialization.
set_seed_and_log_path(log_path="debug.log")

# Wildcard CORS policy: every origin, HTTP method and request header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
# Request body accepted by the ``/process`` endpoint.
# Documented with comments instead of a docstring on purpose: pydantic copies a
# model's docstring into the generated JSON schema description.
class RequestData(BaseModel):
    # Input records to run prediction on; each record is a string-keyed
    # mapping whose values are not validated any further here.
    data: List[Dict[str, Any]]
# Built once at import time; ``process_data`` calls ``task.predict`` on this
# shared instance for every request.
# NOTE(review): the flags below presumably load the best checkpoint from the
# task directory without re-initialising it, dumping its config file or
# regenerating cached data -- confirm against ``from_taskdir``.
task = SchemaGuidedInstructBertTask.from_taskdir(
    "mirror_outputs/Mirror_Pretrain_AllExcluded_2",
    update_config={"regenerate_cache": False},
    dump_configfile=False,
    initialize=False,
    load_best_model=True,
)
@app.post("/process")
def process_data(data: RequestData):
    """Run the loaded task's prediction on the posted records.

    Args:
        data: Request body whose ``data`` field holds the list of input
            records handed to ``task.predict``.

    Returns:
        A dict with three keys: ``ok`` (``False`` when prediction raised),
        ``msg`` (``"success"`` or the formatted traceback) and ``results``
        (whatever ``task.predict`` returned, or ``{}`` on failure).
    """
    input_data = data.data
    ok = True
    msg = ""
    results = {}
    try:
        results = task.predict(input_data)
        msg = "success"
    except KeyboardInterrupt:
        # Bare ``raise`` propagates the in-flight interrupt itself rather than
        # constructing a fresh KeyboardInterrupt.
        raise
    except Exception:
        # Failures are reported through the response body instead of being
        # raised to the framework.
        ok = False
        msg = traceback.format_exc()
        # Record the failure on the server as well; without this the log line
        # below would only show an empty prediction.
        logger.error(f"Prediction failed:\n{msg}")
        # NOTE(review): ``msg`` exposes the full traceback to the client --
        # confirm that is acceptable for every deployment of this service.
    logger.info(f"Data: {input_data}, Prediction: {results}")
    return {"ok": ok, "msg": msg, "results": results}
@app.get("/")
async def api():
    # Serves the front-end page: ``./index.html`` (resolved against the
    # process working directory) is returned with an HTML media type.
    # Kept as a comment rather than a docstring so the description FastAPI
    # generates for this route stays unchanged.
    index_page = FileResponse("./index.html", media_type="text/html")
    return index_page
if __name__ == "__main__":
    # Script entry point: start uvicorn with timestamped log lines.
    # ``log_config`` is uvicorn's own module-level dict, so the format strings
    # are modified in place rather than on a copy.
    log_config = uvicorn.config.LOGGING_CONFIG
    for formatter_name in ("access", "default"):
        formatter = log_config["formatters"][formatter_name]
        # Prepend the timestamp to whatever format uvicorn ships with.
        formatter["fmt"] = "%(asctime)s | " + formatter["fmt"]

    # The app is passed as an import string; the server listens on every
    # interface at port 7860 with debug logging and auto-reload enabled.
    uvicorn.run(
        "src.app.api_backend:app",
        host="0.0.0.0",
        port=7860,
        log_level="debug",
        log_config=log_config,
        reload=True,
    )