text_prompt / app.py
bambadij's picture
fixe
52d4394
from fastapi import FastAPI, HTTPException, status, UploadFile, File
from pydantic import BaseModel
import uvicorn
import logging
import os
import requests
from fastapi.middleware.cors import CORSMiddleware
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
os.environ['HF_HOME'] = '/app/.cache'
Informations = """
-text : Texte à résumer
output:
- Text summary : texte résumé
"""
app = FastAPI(
title='Text Summary',
description=Informations
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
DEFAULT_PROMPT = "Résumez la plainte suivante en 5 phrases concises, en vous concentrant sur les faits principaux et en évitant toute introduction générique : "
class TextSummary(BaseModel):
prompt: str
class RequestModel(BaseModel):
text: str
OLLAMA_URL = "http://localhost:11434" # URL d'Ollama dans le conteneur
@app.get("/")
async def home():
return 'STN BIG DATA'
@app.post("/generate/")
async def generate_text(request: RequestModel):
try:
full_prompt = DEFAULT_PROMPT + request.text
response = requests.post(f"{OLLAMA_URL}/api/generate", json={
"prompt": full_prompt,
"stream": False,
"model": "llama3"
})
if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail=response)
generated_text = response.json().get('response', '')
intro_phrases = [
"Voici un résumé de la plainte en 5 phrases :",
"Résumé :",
"Voici ce qui s'est passé :",
"Cette plainte a été déposée par"
]
for phrase in intro_phrases:
if generated_text.startswith(phrase):
generated_text = generated_text[len(phrase):].strip()
break
return {"summary_text_2": generated_text}
except requests.RequestException as e:
raise HTTPException(status_code=500, detail=f"Erreur de requête : {str(e)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur inattendue : {str(e)}")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)