File size: 2,347 Bytes
ad78747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from inference import inferenceAPI

# from transformers import RobertaTokenizerFast, EncoderDecoderModel

# ------- MODELE HUGGING FACE QUI MARCHE BIEN ------------------------------------
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ckpt = 'mrm8488/camembert2camembert_shared-finetuned-french-summarization'
# tokenizer = RobertaTokenizerFast.from_pretrained(ckpt)
# model = EncoderDecoderModel.from_pretrained(ckpt).to(device)

# def generate_summary(text):
#    inputs = tokenizer([text], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
#    input_ids = inputs.input_ids.to(device)
#    attention_mask = inputs.attention_mask.to(device)
#    output = model.generate(input_ids, attention_mask=attention_mask)
#    return tokenizer.decode(output[0], skip_special_tokens=True)
# ----------------------------------------------------------------------------------


# ------ NOTRE MODELE --------------------------------------------------------------
def summarize(text: str):
    """Summarize *text* with our own model.

    Calls ``inferenceAPI`` (the project's inference entry point, adapted to
    accept plain-text input) and glues the pieces it returns together with
    single spaces. Presumably those pieces are the generated sentences/tokens
    as strings -- TODO confirm against ``inference.inferenceAPI``.
    """
    pieces = inferenceAPI(text)
    return " ".join(pieces)


# ----------------------------------------------------------------------------------

# -------- API ---------------------------------------------------------------------
# ASGI application; the route handlers below are registered on it.
app = FastAPI()

# Jinja2 renderer used by the route handlers to build the HTML pages.
# NOTE: "templates" and "static" are relative paths, so they are resolved
# against the current working directory -- start the server from the
# directory that contains them.
templates = Jinja2Templates(directory="templates")
# Serve static assets (e.g. the CSS) under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")
# NOTE(review): this makes every file in templates/ downloadable as-is at
# /templates/..., including the raw Jinja source of index.html.jinja.
# Confirm the page really needs this (e.g. assets kept next to the
# templates); otherwise consider moving them to static/ and dropping the mount.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


@app.get("/")
async def index(request: Request):
    """Render the landing page (the input form, with no text or summary yet)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html.jinja", context)


# Form submission endpoint: produce the summary and re-render the page.
@app.post("/")
async def prediction(request: Request, text: str = Form(...)):
    """Summarize the submitted ``text`` form field and show it on the index page.

    ``summarize`` is a plain, blocking function that runs our model. Calling
    it directly inside this ``async`` handler would block the event loop, so
    no other request could be served until inference finished.
    ``run_in_threadpool`` runs it in a worker thread and awaits the result,
    keeping the server responsive while this handler stays a coroutine.

    NOTE(review): concurrent requests may now call ``inferenceAPI`` from
    several threads at once -- confirm the underlying model tolerates that.

    Returns:
        The rendered ``index.html.jinja`` template with the original ``text``
        and its ``summary`` in the context.
    """
    summary = await run_in_threadpool(summarize, text)
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, "text": text, "summary": summary}
    )


# ------------------------------------------------------------------------------------


# Development entry point: serve this module's `app` on port 8000 and restart
# the server automatically whenever a source file is saved (reload=True).
if __name__ == "__main__":
    uvicorn.run(
        "api:app",
        reload=True,
        port=8000,
    )