"""FastAPI inference service for a fairseq BART grammar-correction model.

At import time, downloads the model checkpoint and its preprocessed data
files from the Hugging Face Hub, then loads the model for CPU inference.
Exposes a single token-protected POST /inference endpoint.
"""
import os
import re
import secrets

import huggingface_hub
import torch
from fairseq.models.bart import BARTModel
from fastapi import FastAPI, HTTPException

# Configuration comes from the environment; any of these may be unset (None),
# in which case authentication below always fails and the repo id is invalid.
EXPECTED_TOKEN = os.environ.get("EXPECTED_TOKEN")
REPO_NAME = os.environ.get("REPO_NAME")
REPO_NAME_HUGG = os.environ.get("REPO_NAME_HUGG")

app = FastAPI()

# Download the best checkpoint; token=True reuses the locally cached HF login.
path_model = huggingface_hub.hf_hub_download(
    repo_id=f'{REPO_NAME}/{REPO_NAME_HUGG}',
    filename='model/checkpoint_best.pt',
    token=True,
)

# Fairseq needs the binarized data directory (dicts + .bin/.idx shards)
# alongside the checkpoint to rebuild the task/dictionaries.
files = ['dict.src.txt', 'dict.tgt.txt', 'preprocess.log',
         'train.src-tgt.src.bin', 'train.src-tgt.src.idx',
         'train.src-tgt.tgt.bin', 'train.src-tgt.tgt.idx',
         'valid.src-tgt.src.bin', 'valid.src-tgt.src.idx',
         'valid.src-tgt.tgt.bin', 'valid.src-tgt.tgt.idx']
path_data = None
for file in files:
    # All files land in the same local snapshot folder; keep the last path so
    # its parent directory can be handed to fairseq below.
    path_data = huggingface_hub.hf_hub_download(
        repo_id=f'{REPO_NAME}/{REPO_NAME_HUGG}',
        filename=file,
        subfolder='gec_data-bin_ptbr',
        token=True,
    )

bart = BARTModel.from_pretrained(
    os.path.dirname(path_model),
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path=os.path.dirname(path_data),
)
bart.eval()  # inference mode: disable dropout etc.


def posprocessing(frase: str) -> str:
    """Remove any whitespace the model emitted before punctuation marks."""
    return re.sub(r'\s+([.,!?;:])', r'\1', frase)


@app.post('/inference')
async def predict(frase: str, token: str) -> str:
    """Correct *frase* with the BART model; requires the shared API token.

    Raises HTTPException 401 when the token does not match EXPECTED_TOKEN
    (or when EXPECTED_TOKEN is not configured).
    """
    # Constant-time comparison to avoid leaking the token via timing.
    # When EXPECTED_TOKEN is unset, reject every request (as before).
    if EXPECTED_TOKEN is None or not secrets.compare_digest(token, EXPECTED_TOKEN):
        raise HTTPException(status_code=401, detail="Token inválido")
    with torch.no_grad():
        result = bart.sample([frase], beam=1)
    return posprocessing(result[0])