# Hugging Face file-view header (scrape residue, not part of the program):
#   mrmft's picture
#   adding project source
#   4da642e
#   raw
#   history blame
#   No virus
#   1.97 kB
import uvicorn
import os
from typing import Union
from fastapi import FastAPI
from kpe import KPE
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi import APIRouter , Query
from sentence_transformers import SentenceTransformer
import utils
from ranker import get_sorted_keywords
from pydantic import BaseModel
# ASGI application object. The interactive docs are mounted at the site root
# ("/") and the OpenAPI schema is published at "/openapi.json".
app = FastAPI(
    title="AHD Persian KPE",
    description="Keyphrase Extraction",
    # version=config.settings.VERSION,
    docs_url="/",
    openapi_url="/openapi.json",
)
# Directory containing this module; the checkpoint path is resolved against it
# rather than against the current working directory.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
TRAINED_MODEL_ADDR = os.path.join(
    _MODULE_DIR,
    'trained_model',
    'trained_model_10000.pt',
)

# Both models are constructed once, at import time, on the CPU and are then
# shared by the request handler below.
# NOTE(review): the NER model name is the English OntoNotes one although the
# service is titled "Persian KPE" — confirm this is intentional.
kpe = KPE(
    trained_kpe_model=TRAINED_MODEL_ADDR,
    flair_ner_model='flair/ner-english-ontonotes-large',
    device='cpu',
)
ranker_transformer = SentenceTransformer(
    'paraphrase-multilingual-mpnet-base-v2',
    device='cpu',
)
# Permissive CORS policy: every origin, HTTP method and request header is
# accepted, and credentialed requests are allowed.
# NOTE(review): the FastAPI CORS documentation discourages combining wildcard
# origins with allow_credentials=True; confirm whether credentials are needed
# and, if so, list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # str(origin) for origin in config.settings.BACKEND_CORS_ORIGINS
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class KpeParams(BaseModel):
    """Request body accepted by the ``/extract`` endpoint."""

    # Raw document text; it is normalized before extraction.
    text: str
    # Upper bound on the number of keyphrases returned.
    count: int = 10000
    # Forwarded to the extractor as its ``using_ner`` argument.
    using_ner: bool = True
    # When true, the extracted keyphrases are passed through the ranker.
    return_sorted: bool = False
# Router that collects the endpoints declared below; it is attached to `app`
# near the end of this module.
router = APIRouter()
@router.get("/")
def home():
    """Return the service's welcome message.

    NOTE(review): the application is created with ``docs_url="/"``, so the
    docs page is mounted on the same path as this route — confirm which of
    the two is actually served.
    """
    greeting = "Welcome to AHD Keyphrase Extraction Service"
    return greeting
@router.post("/extract", description="extract keyphrase from persian documents")
async def extract(kpe_params: KpeParams):
    """Extract keyphrases from the submitted text.

    Args:
        kpe_params: Request body carrying the text, the maximum number of
            results (``count``), whether NER is used (``using_ner``) and
            whether the results go through the ranker (``return_sorted``).

    Returns:
        The keyphrases, truncated to at most ``count`` entries. When
        ``return_sorted`` is false each entry is a ``(keyphrase, 1)`` tuple;
        otherwise the result is whatever ``get_sorted_keywords`` produces.
        A negative ``count`` is treated as zero.
    """
    # NOTE(review): the extraction and ranking calls below are synchronous;
    # inside an ``async def`` handler they presumably hold the event loop
    # for their whole duration — consider a plain ``def`` handler.
    text = utils.normalize(kpe_params.text)
    kps = kpe.extract(text, using_ner=kpe_params.using_ner)

    if kpe_params.return_sorted:
        kps = get_sorted_keywords(ranker_transformer, text, kps)
    else:
        # Unranked keyphrases are each paired with a constant 1.
        kps = [(kp, 1) for kp in kps]

    # A negative count would act as a from-the-end slice bound and silently
    # drop trailing results, so clamp it to zero before truncating.
    limit = max(kpe_params.count, 0)
    if len(kps) > limit:
        kps = kps[:limit]
    return kps
# Register the endpoints declared on `router` with the application.
app.include_router(router)

if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 7201. The "main:app"
    # import string requires this file to be importable as module `main`.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7201,
    )