AI_Web_Dev1 / app.py
FatimaGr's picture
add
b198707 verified
from fastapi.staticfiles import StaticFiles
import re
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse
import os
from fastapi.middleware.cors import CORSMiddleware
import logging
import matplotlib
matplotlib.use("Agg") # Mode sans interface graphique
logging.basicConfig(level=logging.INFO)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Autorise toutes les origines (à sécuriser en prod)
allow_credentials=True,
allow_methods=["*"], # Autorise toutes les méthodes (GET, POST, etc.)
allow_headers=["*"], # Autorise tous les headers
)
# Charger le modèle Hugging Face
model_name = "Salesforce/codegen-350M-mono"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
VALID_PLOTS = {"histplot", "scatterplot", "barplot", "lineplot", "boxplot"}
@app.post("/generate_viz/")
async def generate_viz(file: UploadFile = File(...), query: str = Form(...)):
try:
if query not in VALID_PLOTS:
return {"error": f"Type de graphique invalide. Choisissez parmi : {', '.join(VALID_PLOTS)}"}
df = pd.read_excel(file.file)
numeric_cols = df.select_dtypes(include=["number"]).columns
if len(numeric_cols) < 2:
return {"error": "Le fichier doit contenir au moins deux colonnes numériques."}
x_col, y_col = numeric_cols[:2]
# Contraintes spécifiques pour éviter l'erreur avec histplot
if query == "histplot":
prompt_y = ""
else:
prompt_y = f', y="{y_col}"'
# Générer l'invite pour le modèle
prompt = f"""
### Génère uniquement du code Python fonctionnel pour tracer un {query} avec Matplotlib et Seaborn ###
# Contraintes :
# - Utilise 'df' sans recréer de nouvelles données
# - Axe X : '{x_col}'
# - Enregistre le graphique sous 'plot.png'
# - Ne génère que du code Python valide, sans texte explicatif
# Contraintes spécifiques pour sns.histplot :
# - N'inclut pas "y=" car histplot ne supporte qu'un axe
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(8,6))
sns.{query}(data=df, x="{x_col}"{prompt_y})
plt.savefig("plot.png")
plt.close()
"""
# Génération du code
inputs = tokenizer(prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=120, pad_token_id=tokenizer.eos_token_id)
generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# Nettoyage du code
generated_code = re.sub(r"(import matplotlib.pyplot as plt\nimport seaborn as sns\n)+", "import matplotlib.pyplot as plt\nimport seaborn as sns\n", generated_code)
if generated_code.strip().endswith("sns."):
generated_code = generated_code.rsplit("\n", 1)[0] # Supprime la dernière ligne incomplète
print("🔹 Code généré par l'IA :\n", generated_code)
# Vérification syntaxique avant exécution
try:
compile(generated_code, "<string>", "exec")
except SyntaxError as e:
return {"error": f"Erreur de syntaxe détectée : {e}\nCode généré :\n{generated_code}"}
# Vérification des données
print(df.head()) # Affiche les premières lignes du dataframe
print(df.dtypes) # Vérifie les types de colonnes
print(f"Colonne '{x_col}' - Valeurs uniques:", df[x_col].unique())
if df.empty or x_col not in df.columns or df[x_col].isnull().all():
return {"error": f"La colonne '{x_col}' est absente ou ne contient pas de données valides."}
# Exécution du code généré
exec_env = {"df": df, "plt": plt, "sns": sns, "pd": pd}
exec(generated_code, exec_env)
# Vérification de l'image générée
img_path = "plot.png"
if not os.path.exists(img_path):
return {"error": "Le fichier plot.png n'a pas été généré."}
if os.path.getsize(img_path) == 0:
return {"error": "Le fichier plot.png est vide."}
plt.close()
return FileResponse(img_path, media_type="image/png")
except Exception as e:
return {"error": f"Erreur lors de la génération du graphique : {str(e)}"}
# ✅ Déplace ici le montage des fichiers statiques
app.mount("/", StaticFiles(directory="static", html=True), name="static")
# Redirection vers index.html
@app.get("/")
async def root():
return RedirectResponse(url="/index.html")