"""FastAPI web front-end for the project's text summarization model.

Serves a single Jinja page: GET renders the empty form, POST runs the
model on the submitted text and re-renders the page with the summary.
"""
from pathlib import Path

import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from inference import inferenceAPI

# from transformers import RobertaTokenizerFast, EncoderDecoderModel

# ------- Hugging Face model that works well (kept for reference) -----------------
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ckpt = 'mrm8488/camembert2camembert_shared-finetuned-french-summarization'
# tokenizer = RobertaTokenizerFast.from_pretrained(ckpt)
# model = EncoderDecoderModel.from_pretrained(ckpt).to(device)
# def generate_summary(text):
#     inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
#     input_ids = inputs.input_ids.to(device)
#     attention_mask = inputs.attention_mask.to(device)
#     output = model.generate(input_ids, attention_mask=attention_mask)
#     return tokenizer.decode(output[0], skip_special_tokens=True)
# ----------------------------------------------------------------------------------

# ------ Our model -----------------------------------------------------------------


def summarize(text: str) -> str:
    """Summarize *text* with the project's own model.

    Calls ``inferenceAPI`` (the inference function, adapted for plain-text
    input) and joins the pieces it returns with single spaces.
    """
    # NOTE(review): assumes inferenceAPI returns an iterable of strings
    # (e.g. summary sentences). If it returned a plain str, join would insert
    # a space between every character. TODO confirm.
    return " ".join(inferenceAPI(text))


# ----------------------------------------------------------------------------------
# -------- API ---------------------------------------------------------------------

# Resolve asset directories relative to this file, not the current working
# directory, so the app works no matter where the server is launched from.
BASE_DIR = Path(__file__).resolve().parent
TEMPLATES_DIR = BASE_DIR / "templates"
STATIC_DIR = BASE_DIR / "static"

app = FastAPI()

templates = Jinja2Templates(directory=str(TEMPLATES_DIR))
# Static assets (everything CSS-related).
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# NOTE(review): this serves the raw Jinja template sources over HTTP. Keep it
# only if the page actually links to files under /templates.
app.mount("/templates", StaticFiles(directory=str(TEMPLATES_DIR)), name="templates")


@app.get("/")
async def index(request: Request):
    """Render the page with an empty input form."""
    return templates.TemplateResponse("index.html.jinja", {"request": request})


@app.post("/")
def prediction(request: Request, text: str = Form(...)):
    """Summarize the submitted form text and re-render the page with the result.

    Declared as a plain ``def`` on purpose (the original was ``async def``).
    Model inference is synchronous and blocking. FastAPI runs sync path
    operations in a worker threadpool, so a long inference no longer freezes
    the event loop and every other request while it runs.
    """
    summary = summarize(text)
    return templates.TemplateResponse(
        "index.html.jinja",
        {"request": request, "text": text, "summary": summary},
    )


# ------------------------------------------------------------------------------------

# Start the server and reload it on every change saved in the repo.
if __name__ == "__main__":
    # reload=True requires an "module:attribute" import string. Derive the
    # module name from this file instead of hard-coding "api", so renaming the
    # file does not break startup.
    uvicorn.run(f"{Path(__file__).stem}:app", port=8000, reload=True)