# Source captured from a Hugging Face Space page.
# Space status at capture time: Runtime error
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from inference import inferenceAPI
# from transformers import RobertaTokenizerFast, EncoderDecoderModel
# ------- HUGGING FACE MODEL THAT WORKS WELL (disabled; kept for reference) --------
# NOTE: re-enabling this snippet also requires `import torch` (used for `device`).
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ckpt = 'mrm8488/camembert2camembert_shared-finetuned-french-summarization'
# tokenizer = RobertaTokenizerFast.from_pretrained(ckpt)
# model = EncoderDecoderModel.from_pretrained(ckpt).to(device)
# def generate_summary(text):
#     inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
#     input_ids = inputs.input_ids.to(device)
#     attention_mask = inputs.attention_mask.to(device)
#     output = model.generate(input_ids, attention_mask=attention_mask)
#     return tokenizer.decode(output[0], skip_special_tokens=True)
# ----------------------------------------------------------------------------------
# ------ OUR MODEL -----------------------------------------------------------------
# Calls the project's inference function, adapted to take a plain text input.
def summarize(text: str) -> str:
    """Summarize *text* with the project's ``inferenceAPI``.

    Args:
        text: Raw text submitted by the user.

    Returns:
        The items produced by ``inferenceAPI(text)`` joined with single spaces.
    """
    # NOTE(review): assumes inferenceAPI returns an iterable of strings
    # (e.g. one per generated sentence) — TODO confirm. If it returned a plain
    # str, join would insert a space between every character.
    return " ".join(inferenceAPI(text))
# ----------------------------------------------------------------------------------
# -------- API ---------------------------------------------------------------------
app = FastAPI()

# Jinja2 templates used to render the HTML pages.
templates = Jinja2Templates(directory="templates")

# Serve static assets (CSS, etc.) from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): this mount also serves the raw template sources under /templates.
# Keep it only if a page links to files inside templates/ — TODO confirm;
# otherwise it can be removed.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")
# NOTE(review): the route decorator was missing, so this handler was never
# registered with the app. "/" is assumed as the page URL — confirm against the
# links/form actions in the templates.
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render ``index.html.jinja`` with no submitted text or summary yet.

    Args:
        request: The incoming request; Jinja2Templates requires it in the context.

    Returns:
        The rendered template response.
    """
    return templates.TemplateResponse("index.html.jinja", {"request": request})
# Produce a prediction: summarize the submitted text and re-render the page.
# NOTE(review): the route decorator was missing, so this handler was never
# registered. POST "/" is assumed — it must match the form's `action` in
# index.html.jinja; TODO confirm.
@app.post("/", response_class=HTMLResponse)
async def prediction(request: Request, text: str = Form(...)):
    """Summarize the posted ``text`` form field and render it on the index page.

    Args:
        request: The incoming request; Jinja2Templates requires it in the context.
        text: Required form field holding the text to summarize.

    Returns:
        ``index.html.jinja`` rendered with the original text and its summary.
    """
    # NOTE(review): summarize() is synchronous and runs inside an async handler,
    # so the event loop is blocked for the duration of each inference call.
    # Consider offloading it (e.g. fastapi.concurrency.run_in_threadpool) once
    # inferenceAPI is confirmed to be safe to call from worker threads.
    summary = summarize(text)
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, "text": text, "summary": summary}
    )
# ----------------------------------------------------------------------------------
# Launch the development server; reload=True restarts it whenever a watched
# source file is saved in the repo.
if __name__ == "__main__":
    # uvicorn needs an import string (not the app object) to enable reload.
    # "api:app" presumes this file is api.py and is run from its own directory.
    uvicorn.run("api:app", port=8000, reload=True)