SummaryProject / api.py
EveSa's picture
fix : src path fix
332c8a6
raw
history blame
No virus
2.35 kB
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from src.inference import inferenceAPI
# from transformers import RobertaTokenizerFast, EncoderDecoderModel
# ------- MODELE HUGGING FACE QUI MARCHE BIEN ------------------------------------
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ckpt = 'mrm8488/camembert2camembert_shared-finetuned-french-summarization'
# tokenizer = RobertaTokenizerFast.from_pretrained(ckpt)
# model = EncoderDecoderModel.from_pretrained(ckpt).to(device)
# def generate_summary(text):
# inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
# input_ids = inputs.input_ids.to(device)
# attention_mask = inputs.attention_mask.to(device)
# output = model.generate(input_ids, attention_mask=attention_mask)
# return tokenizer.decode(output[0], skip_special_tokens=True)
# ----------------------------------------------------------------------------------
# ------ NOTRE MODELE --------------------------------------------------------------
# Thin wrapper around the project's inference function for a plain-text input.
def summarize(text: str):
    """Return the summary of ``text`` as a single string.

    ``text`` is handed to ``inferenceAPI``; the pieces it returns are
    joined with single spaces.
    """
    pieces = inferenceAPI(text)
    return " ".join(pieces)
# ----------------------------------------------------------------------------------
# -------- API ---------------------------------------------------------------------
app = FastAPI()

# Jinja2 renderer reading its templates from the "templates" directory.
templates = Jinja2Templates(directory="templates")

# Static assets (CSS and the like) are served from the "static" directory.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# The template directory is additionally exposed as static files.
# NOTE(review): this appears to make the raw template sources downloadable
# under /templates -- confirm that this is intended.
app.mount(
    "/templates",
    StaticFiles(directory="templates"),
    name="templates",
)
@app.get("/")
async def index(request: Request):
    """Render the index page with no text and no summary in the context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html.jinja", context)
# Serve predictions: summarize the submitted text and show the result.
@app.post("/")
async def prediction(request: Request, text: str = Form(...)):
    """Summarize the submitted form text and re-render the index page.

    Args:
        request: Incoming request, passed through to the template context.
        text: Required ``text`` form field holding the text to summarize.

    Returns:
        The ``index.html.jinja`` template response whose context carries
        the original ``text`` and the generated ``summary``.
    """
    # Imported locally so this handler stays self-contained with respect to
    # the file's import header.
    from fastapi.concurrency import run_in_threadpool

    # summarize() is a synchronous call; running it directly inside this
    # coroutine would block the event loop until it returns and stall every
    # other request. Offload it to the worker thread pool instead.
    summary = await run_in_threadpool(summarize, text)

    context = {"request": request, "text": text, "summary": summary}
    return templates.TemplateResponse("index.html.jinja", context)
# ------------------------------------------------------------------------------------
# Start the server when run as a script; reload=True restarts it on every
# saved change in the repo.
if __name__ == "__main__":
    uvicorn.run(
        "api:app",
        port=8000,
        reload=True,
    )