|
from fastapi import FastAPI, File, UploadFile, HTTPException, Form |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import os |
|
import logging |
|
from huggingface_hub import InferenceClient |
|
from dotenv import load_dotenv |
|
import hashlib |
|
|
|
|
|
# Emit INFO-level records so upload and code-generation events are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load variables from a local .env file (if present) into the environment
# before any os.getenv() lookups below.
load_dotenv()

app = FastAPI()

# CORS: every origin is allowed, so credentialed requests are disabled.
# A wildcard origin together with allow_credentials=True is not a valid
# combination; list explicit origins here before re-enabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directory where uploaded spreadsheets are stored.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Directory (served under /static) where rendered plots are written.
# Created before the static mount so the "static" directory is guaranteed
# to exist when StaticFiles is constructed.
IMAGES_DIR = os.path.join("static", "images")
os.makedirs(IMAGES_DIR, exist_ok=True)

app.mount("/static", StaticFiles(directory="static"), name="static")

# The inference token is mandatory; fail fast at import time without it.
API_TOKEN = os.getenv("HF_TOKEN")
if not API_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set.")

# NOTE(review): this looks like a Google Gemini model id while the client is
# a Hugging Face InferenceClient -- confirm the backend can resolve it.
MODEL_NAME = "gemini-2.5-pro-preview-03-25"
client = InferenceClient(model=MODEL_NAME, token=API_TOKEN)
|
|
|
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded .xlsx or .csv file in UPLOAD_DIR.

    Args:
        file: The multipart upload.

    Returns:
        A dict ``{"filename": <stored name>}``; the stored name is the base
        name of the client-supplied filename.

    Raises:
        HTTPException: 400 when the filename is missing or does not end in
            ``.xlsx`` or ``.csv``.
    """
    # The filename is client-controlled: keep only its final path component
    # (treating both "/" and "\\" as separators) so that a name such as
    # "../../app.py" cannot escape UPLOAD_DIR. A missing filename becomes ""
    # and is rejected by the extension check below.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))

    if not safe_name.endswith((".xlsx", ".csv")):
        raise HTTPException(status_code=400, detail="File must be an Excel (.xlsx) or CSV file")

    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as buffer:
        buffer.write(await file.read())

    logger.info(f"File uploaded: {safe_name}")
    return {"filename": safe_name}
|
|
|
# Lines of model output starting with any of these are dropped. The match is
# purely textual, so e.g. a line starting with "default" is dropped by "def".
# "```" removes Markdown code fences that models often wrap code in.
_SKIPPED_LINE_PREFIXES = ("#", "def", "class", '"""', "'''", "```")

# Lines containing any of these substrings are dropped.
_BLOCKED_SUBSTRINGS = ("pd.read_csv", "pd.read_excel", "http", "raise", "print")


def _resolve_upload_path(filename: str) -> str:
    """Return the path of *filename* inside UPLOAD_DIR, rejecting path tricks."""
    safe_name = os.path.basename(filename)
    # Anything with directory components (or a bare "."/"..") could point
    # outside UPLOAD_DIR, so refuse it instead of silently rewriting it.
    if safe_name != filename or safe_name in {"", ".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid filename.")
    return os.path.join(UPLOAD_DIR, safe_name)


def _filter_executable_lines(code: str) -> str:
    """Reduce model output to the flat list of statements that will be run."""
    kept = []
    for raw_line in code.splitlines():
        # Indentation is discarded on purpose: every surviving line is run as
        # a top-level statement.
        line = raw_line.strip()
        if not line or line.startswith(_SKIPPED_LINE_PREFIXES):
            continue
        if any(snippet in line for snippet in _BLOCKED_SUBSTRINGS):
            continue
        kept.append(line)
    # The figure is saved to disk by the caller, so plt.show() is removed.
    return "\n".join(kept).replace("plt.show()", "").strip()


@app.post("/generate-visualization/")
async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
    """Generate a plot for an uploaded file from a natural-language prompt.

    The uploaded file is loaded into a DataFrame, the model is asked for
    plotting code, the filtered code is executed, and the resulting figure is
    saved under IMAGES_DIR.

    Args:
        prompt: Description of the desired visualization.
        filename: Name of a file previously stored in UPLOAD_DIR.

    Returns:
        A dict with ``plot_url`` (path of the saved PNG under /static/images)
        and ``generated_code`` (the model output after quote stripping).

    Raises:
        HTTPException: 400 for an invalid filename or an unreadable/empty
            file, 404 when the file does not exist, 500 when the model call
            fails, yields no usable code, or the code fails to execute.
    """
    file_path = _resolve_upload_path(filename)

    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found on server.")

    try:
        if filename.endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("File is empty.")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading file: {str(e)}") from e

    # Column labels are not guaranteed to be strings (e.g. numeric headers),
    # so convert them before joining.
    column_names = ', '.join(str(column) for column in df.columns)
    input_text = "\n".join([
        "",
        f"Given the DataFrame 'df' with columns {column_names} and preview:",
        df.head().to_string(),
        f"Write Python code to: {prompt}",
        "- Use ONLY 'df' (no external data loading).",
        "- Use pandas (pd), matplotlib.pyplot (plt), or seaborn (sns).",
        "- Include axis labels and a title.",
        "- Output ONLY executable code (no comments, functions, print, or triple quotes).",
        "",
    ])

    try:
        generated_code = client.text_generation(input_text, max_new_tokens=500)
        logger.info(f"Generated code:\n{generated_code}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}") from e

    if not generated_code.strip():
        raise HTTPException(status_code=500, detail="No code generated by the AI model.")

    # Unwrap output that the model enclosed in triple quotes.
    generated_code = generated_code.strip()
    if generated_code.startswith('"""') or generated_code.startswith("'''"):
        generated_code = generated_code.split('"""')[1] if '"""' in generated_code else generated_code.split("'''")[1]
    if generated_code.endswith('"""') or generated_code.endswith("'''"):
        generated_code = generated_code.rsplit('"""')[0] if '"""' in generated_code else generated_code.rsplit("'''")[0]
    generated_code = generated_code.strip()

    executable_code = _filter_executable_lines(generated_code)
    logger.info(f"Executable code:\n{executable_code}")

    # Without this guard an empty script would "succeed" and save a blank figure.
    if not executable_code:
        raise HTTPException(status_code=500, detail="No executable code generated by the AI model.")

    # The plot name is derived from filename + prompt, so repeating a request
    # overwrites the same image file.
    plot_hash = hashlib.md5(f"{filename}_{prompt}".encode()).hexdigest()[:8]
    plot_filename = f"plot_{plot_hash}.png"
    plot_path = os.path.join(IMAGES_DIR, plot_filename)

    exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
    try:
        # SECURITY: this executes model-generated code inside the server
        # process. The line filter above is not a sandbox; run this service
        # only where arbitrary code execution is acceptable, or isolate it.
        exec(executable_code, exec_globals)
        plt.savefig(plot_path, bbox_inches="tight")
    except Exception as e:
        logger.error(f"Error executing code:\n{executable_code}\nException: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}") from e
    finally:
        # Always release figures, including on failure, so they neither leak
        # memory nor bleed into the next request's plot.
        plt.close("all")

    if not os.path.exists(plot_path):
        raise HTTPException(status_code=500, detail="Plot file was not created.")

    return {"plot_url": f"/static/images/{plot_filename}", "generated_code": generated_code}
|
|
|
@app.get("/")
async def serve_frontend():
    """Serve the contents of static/index.html as the HTML response for "/"."""
    # Decode explicitly as UTF-8 instead of the platform's locale encoding,
    # so the page is read the same way on every host.
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())