|
|
from fastapi import FastAPI |
|
|
from huggingface_hub import snapshot_download |
|
|
from huggingface_hub import hf_hub_download |
|
|
import os |
|
|
|
|
|
# Banner written to stdout as soon as the module starts executing.
print("Version 1")

# FastAPI application instance; the /predict route in this file is registered on it.
app = FastAPI()
|
|
|
|
|
|
|
|
def download_file_from_hf(repo_id: str, filename: str) -> str:
    """
    Fetch one file from a Hugging Face repo and store it under ~/.sinatools.

    The destination directory is created first if it does not exist yet.

    Args:
        repo_id (str): Hugging Face repo id, e.g. "SinaLab/ArabGlossBERT"
        filename (str): Path of the file inside the repo, e.g. "config.json"

    Returns:
        str: Path of the downloaded file, exactly as returned by
        ``hf_hub_download``.
    """
    sinatools_home = os.path.expanduser("~/.sinatools")
    os.makedirs(sinatools_home, exist_ok=True)

    # NOTE(review): newer huggingface_hub releases document
    # ``local_dir_use_symlinks`` as deprecated; confirm against the pinned
    # version before dropping it.
    request = {
        "repo_id": repo_id,
        "filename": filename,
        "local_dir": sinatools_home,
        "local_dir_use_symlinks": False,
    }
    return hf_hub_download(**request)
|
|
|
|
|
def download_folder_from_hf(repo_id: str, folder_name: str) -> str:
    """
    Downloads a folder from a Hugging Face model repo into ~/.sinatools

    Only repo files matching ``<folder_name>/**`` are requested.

    Args:
        repo_id (str): Hugging Face repo id, e.g. "SinaLab/ArabGlossBERT"
        folder_name (str): Path of the folder inside the repo, e.g.
            "bert-base-arabertv02". Leading and trailing "/" are ignored.

    Returns:
        str: The directory returned by ``snapshot_download`` joined with the
        normalised folder name.
    """
    # Normalise the folder name before it is used in two places:
    #  * "name/" would otherwise produce the pattern "name//**", which is
    #    unlikely to match files stored as "name/<file>" (assumes
    #    fnmatch-style matching in huggingface_hub -- confirm);
    #  * a leading "/" would make os.path.join() below throw away the
    #    download directory and return a path rooted at "/".
    folder_name = folder_name.strip("/")

    target_dir = os.path.expanduser("~/.sinatools")
    # Create the directory explicitly, mirroring download_file_from_hf, so
    # both helpers prepare ~/.sinatools the same way.
    os.makedirs(target_dir, exist_ok=True)

    # NOTE(review): newer huggingface_hub releases document
    # ``local_dir_use_symlinks`` as deprecated; confirm against the pinned
    # version before dropping it.
    local_path = snapshot_download(
        repo_id=repo_id,
        allow_patterns=f"{folder_name}/**",
        local_dir=target_dir,
        local_dir_use_symlinks=False,
    )

    return os.path.join(local_path, folder_name)
|
|
|
|
|
|
|
|
def _prefetch_sinatools_assets():
    # Download every folder, then every single file, in the listed order.
    folders = (
        ("SinaLab/Wojood_model", "Wj27012000.tar"),
        ("SinaLab/ArabGlossBERT", "bert-base-arabertv02_22_May_2021_00h_allglosses_unused01"),
        ("SinaLab/ArabGlossBERT", "bert-base-arabertv02"),
    )
    files = (
        ("SinaLab/ArabGlossBERT", "one_gram.pickle"),
        ("SinaLab/ArabGlossBERT", "two_grams.pickle"),
        ("SinaLab/ArabGlossBERT", "three_grams.pickle"),
        ("SinaLab/ArabGlossBERT", "four_grams.pickle"),
        ("SinaLab/ArabGlossBERT", "five_grams.pickle"),
        ("SinaLab/ALMA", "lemmas_dic.pickle"),
    )

    for repo, folder in folders:
        download_folder_from_hf(repo, folder)
    for repo, name in files:
        download_file_from_hf(repo, name)


# The downloads run at import time, before the sinatools import that follows
# this section. NOTE(review): presumably sinatools expects these files under
# ~/.sinatools when it is imported -- confirm before moving this code.
print("Start loading")
_prefetch_sinatools_assets()
print("Finish loading")
|
|
|
|
|
from sinatools.wsd.disambiguator import disambiguate |
|
|
from pydantic import BaseModel |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
# Request body accepted by the /predict endpoint.
# Documented with comments rather than a docstring so the schema generated
# from this model stays exactly as it was.
class SALMARequest(BaseModel):
    # Input text that is passed to disambiguate().
    text: str
|
|
|
|
|
# POST /predict: run disambiguate() on the submitted text and wrap its output
# in a JSON envelope with "resp", "statusText" and "statusCode" fields.
# Kept as comments rather than a docstring because FastAPI would publish a
# docstring as the endpoint description in the OpenAPI output.
@app.post("/predict")
def predict(request: SALMARequest):
    sentence = request.text

    print("Start disambiguate")

    # The HTTP status is always 200 here; "statusCode": 0 is a separate field
    # carried inside the response body.
    return JSONResponse(
        status_code=200,
        media_type="application/json",
        content={
            "resp": disambiguate(sentence),
            "statusText": "OK",
            "statusCode": 0,
        },
    )
|
|
|