Spaces:
Sleeping
Sleeping
File size: 2,116 Bytes
280d87f 5be68bc d2ec6cb 5be68bc 280d87f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import os
import uvicorn
from Destinations.get_destinations import (get_destinations_list,
get_question_vector)
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from Model.model_predict_onnx import onnx_predictor
# Router holding the model endpoints: every route registered on it is served
# under the "/model" prefix and grouped under the "Model" tag in the OpenAPI docs.
router = APIRouter(tags=["Model"], prefix="/model")
@router.get("/get_question_tags/{question}")
async def get_question_tags(question: str):
    """Predict tags for ``question`` and return them as JSON.

    The sentence and the tags reported by ``onnx_predictor.predict`` are also
    printed to stdout for debugging.

    Args:
        question: Question text taken from the URL path.

    Returns:
        ``{"question_tags": <tags returned by onnx_predictor.predict>}``.
    """
    sentence, tags = onnx_predictor.predict(question)
    for label, value in (("Sentence:", sentence), ("Predicted Tags:", tags)):
        print(label, value)
    return {"question_tags": tags}
@router.get("/get_destinations_list/{question_tags}/{top_k}")
async def get_destinations_list_by_tags_api(question_tags: str, top_k: int):
    """Return destinations for already-known question tags.

    The function name is deliberately distinct from the by-question endpoint
    further down, which is bound to ``get_destinations_list_api``. Reusing that
    name here would shadow this function at module level (ruff F811), and
    both routes would get the same summary in the generated docs.

    Args:
        question_tags: Tag text taken verbatim from the URL path and passed to
            ``get_question_vector``. The by-question endpoint joins predicted
            tags with single spaces before the same call, so a space-separated
            string is presumably the expected format (TODO confirm).
        top_k: Number of destinations requested from ``get_destinations_list``.
            It is declared ``int`` so FastAPI rejects non-numeric values with a
            422 response instead of ``int()`` raising and surfacing as a 500.

    Returns:
        ``{"destinations_list": <result of get_destinations_list>}``.
    """
    question_vector = get_question_vector(question_tags)
    # int() is a no-op for validated requests. It keeps direct Python callers
    # that still pass a numeric string working.
    destinations_list = get_destinations_list(question_vector, int(top_k))
    print("destinations_list:", destinations_list)
    return {"destinations_list": destinations_list}
@router.get("/get_destinations_list_by_question/{question}/{top_k}")
async def get_destinations_list_api(question: str, top_k: int):
    """Predict tags for a free-text question, then return matching destinations.

    This combines tag prediction (``onnx_predictor.predict``) with the
    destination lookup (``get_question_vector`` + ``get_destinations_list``) in
    a single request. Intermediate results are printed to stdout for debugging.

    NOTE(review): ``predict``, ``get_question_vector`` and
    ``get_destinations_list`` are called synchronously inside an ``async``
    endpoint, so the event loop is blocked while they run. A plain ``def``
    endpoint (or ``run_in_threadpool``) would avoid that, but only if those
    helpers are thread-safe, so confirm before changing.

    Args:
        question: Question text taken from the URL path.
        top_k: Number of destinations requested from ``get_destinations_list``.
            It is declared ``int`` so FastAPI rejects non-numeric values with a
            422 response instead of ``int()`` raising and surfacing as a 500.

    Returns:
        ``{"destinations_list": <result of get_destinations_list>}``.
    """
    original_sentence, predicted_tags = onnx_predictor.predict(question)
    print("Sentence:", original_sentence)
    print("Predicted Tags:", predicted_tags)

    # get_question_vector takes a single string, so the predicted tags (which
    # must be an iterable of str for this join) are joined with spaces.
    tags_text = " ".join(predicted_tags)
    question_vector = get_question_vector(tags_text)
    # int() is a no-op for validated requests. It keeps direct Python callers
    # that still pass a numeric string working.
    destinations_list = get_destinations_list(question_vector, int(top_k))
    print("destinations_list:", destinations_list)
    return {"destinations_list": destinations_list}
# ASGI application. The interactive API docs are served at the site root ("/").
app = FastAPI(docs_url="/")
app.include_router(router)

# Fully permissive CORS: every origin, method and request header is accepted,
# and all response headers are exposed to the browser.
# NOTE(review): allow_credentials=True combined with allow_origins=['*'] makes
# Starlette's CORSMiddleware reflect the request's Origin (instead of "*") on
# preflight and cookie-bearing requests. In effect, any site may make
# credentialed calls. Nothing in this file reads cookies or auth headers, so
# credentials are probably unnecessary. Confirm with the frontend (and the
# installed Starlette version) before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_methods=['*'],
    allow_headers=['*'],
    allow_credentials=True,
    expose_headers=['*',],
)
if __name__ == "__main__":
    # Script entry point: serve on all interfaces, using $PORT when it is set
    # and falling back to 7860 otherwise.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|