text_summary / app1.py
bambadij's picture
'fixe'
8d23cfc
raw
history blame
2.95 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import logging
import os
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForCausalLM
# Configurer les répertoires de cache pour Transformers
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
os.environ['HF_HOME'] = '/app/.cache'
# Informations générales pour l'API
Informations = """
-text : Texte à resumer
output:
- Text summary : texte resumé
"""
app = FastAPI(
title='Text Summary',
description=Informations
)
# Configurer les logs
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Ajouter le middleware CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Prompt par défaut
DEFAULT_PROMPT = "Fais nous un résumé descriptif en français de la plainte suivante en 4 phrases concises, en vous concentrant sur les faits principaux et en évitant toute introduction générique. Nettoie également le texte si nécessaire : "
# Modèle de la requête
class RequestModel(BaseModel):
text: str
# Charger le modèle et le tokenizer
model_name = "models/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f"
tokenizer = AutoTokenizer.from_pretrained(model_name,token="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")
model = AutoModelForCausalLM.from_pretrained(model_name,token="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")
@app.get("/")
async def home():
return 'STN BIG DATA'
@app.post("/generate/")
async def generate_text(request: RequestModel):
try:
# Combiner le prompt par défaut et le texte de l'utilisateur
full_prompt = DEFAULT_PROMPT + request.text
# Tokeniser l'entrée
input_ids = tokenizer.encode(full_prompt, return_tensors="pt")
# Générer du texte avec le modèle
output = model.generate(input_ids, max_length=150, num_return_sequences=1)
# Décoder la sortie pour obtenir le texte généré
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
# Optionnel : nettoyage du texte généré pour enlever les phrases génériques
intro_phrases = [
"Voici un résumé de la plainte en 5 phrases :",
"Résumé :",
"Voici ce qui s'est passé :",
"Cette plainte a été déposée par"
]
for phrase in intro_phrases:
if generated_text.startswith(phrase):
generated_text = generated_text[len(phrase):].strip()
break
return {"summary_text": generated_text}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur inattendue : {str(e)}")
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=8080, reload=True)