# salma-api / app.py
# Last commit 5b9d21c by TymaaHammouda: "Update wojood_model path"
from fastapi import FastAPI
from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download
import os
# Startup marker printed at import time; presumably a manual revision tag so
# deployment logs show which version of this script is running.
print("Version 1")
# The FastAPI application; the `/predict` route below is registered on it.
app = FastAPI()
def download_file_from_hf(repo_id, filename):
    """Fetch a single file from a Hugging Face repo into ``~/.sinatools``.

    The directory is created first if it does not exist yet.

    Args:
        repo_id (str): Hugging Face repo id, e.g. "SinaLab/ArabGlossBERT".
        filename (str): Path of the file inside the repo, e.g. "config.json".

    Returns:
        str: Local path of the downloaded file (inside ``~/.sinatools``),
            exactly as returned by ``hf_hub_download``.
    """
    sinatools_home = os.path.expanduser("~/.sinatools")
    os.makedirs(sinatools_home, exist_ok=True)

    # Ask for real files in local_dir rather than symlinks into the HF cache.
    # NOTE(review): newer huggingface_hub releases deprecate/ignore
    # `local_dir_use_symlinks`; confirm against the pinned version.
    download_kwargs = dict(
        repo_id=repo_id,
        filename=filename,
        local_dir=sinatools_home,
        local_dir_use_symlinks=False,
    )
    return hf_hub_download(**download_kwargs)
def download_folder_from_hf(repo_id, folder_name):
    """
    Downloads a folder from a Hugging Face model repo into ~/.sinatools

    Only repo files under ``<folder_name>/`` are fetched; the rest of the repo
    is skipped.

    Args:
        repo_id (str): Hugging Face repo id, e.g. "SinaLab/ArabGlossBERT"
        folder_name (str): Path of the folder inside the repo,
            e.g. "bert-base-arabertv02"
    Returns:
        str: Local path of the folder (``<download root>/<folder_name>``).
            If no repo file matched, a warning is printed and the returned
            path may not exist.
    """
    target_dir = os.path.expanduser("~/.sinatools")
    # Same as download_file_from_hf: make sure the shared download dir exists.
    os.makedirs(target_dir, exist_ok=True)

    # A trailing "/" would otherwise yield the glob "name//**", which never
    # matches a repo path.
    pattern_root = folder_name.rstrip("/")
    local_path = snapshot_download(
        repo_id=repo_id,
        allow_patterns=f"{pattern_root}/**",
        local_dir=target_dir,
        local_dir_use_symlinks=False,
    )

    folder_path = os.path.join(local_path, folder_name)
    if not os.path.isdir(folder_path):
        # Nothing above verifies that the pattern matched any repo file (e.g.
        # when `folder_name` is actually a single file). Make that visible
        # instead of silently returning a missing path; keep best-effort
        # semantics so startup is not aborted.
        print(
            f"Warning: no files matched '{pattern_root}/**' in {repo_id}; "
            f"{folder_path} does not exist"
        )
    return folder_path
print("Start loading")
# All assets are fetched before `sinatools` is imported below; presumably its
# modules read them from ~/.sinatools at import time (confirm).
# NOTE(review): download_folder_from_hf only fetches "<name>/**", so
# "Wj27012000.tar" must be a *directory* in SinaLab/Wojood_model; if it is a
# single tar file, nothing is downloaded for it. Verify against the repo.
for _repo_id, _folder in (
    ("SinaLab/Wojood_model", "Wj27012000.tar"),
    ("SinaLab/ArabGlossBERT", "bert-base-arabertv02_22_May_2021_00h_allglosses_unused01"),
    ("SinaLab/ArabGlossBERT", "bert-base-arabertv02"),
):
    download_folder_from_hf(_repo_id, _folder)
for _repo_id, _filename in (
    ("SinaLab/ArabGlossBERT", "one_gram.pickle"),
    ("SinaLab/ArabGlossBERT", "two_grams.pickle"),
    ("SinaLab/ArabGlossBERT", "three_grams.pickle"),
    ("SinaLab/ArabGlossBERT", "four_grams.pickle"),
    ("SinaLab/ArabGlossBERT", "five_grams.pickle"),
    ("SinaLab/ALMA", "lemmas_dic.pickle"),
):
    download_file_from_hf(_repo_id, _filename)
print("Finish loading")
from sinatools.wsd.disambiguator import disambiguate
from pydantic import BaseModel
from fastapi.responses import JSONResponse
class SALMARequest(BaseModel):
    # Request body for POST /predict.
    # `text` is the input string handed to `disambiguate()` by the endpoint.
    text: str
@app.post("/predict")
def predict(request: SALMARequest):
    """Run sinatools' ``disambiguate`` on the submitted text.

    Always answers HTTP 200 with the JSON envelope
    ``{"resp": <disambiguate output>, "statusText": "OK", "statusCode": 0}``.
    """
    print("Start disambiguate")
    # Key order is kept as resp, statusText, statusCode in the serialized JSON.
    envelope = dict(
        resp=disambiguate(request.text),
        statusText="OK",
        statusCode=0,
    )
    return JSONResponse(
        status_code=200,
        media_type="application/json",
        content=envelope,
    )