File size: 4,864 Bytes
7d1e58a
 
 
 
ff768e2
 
7d1e58a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfc4e12
7d1e58a
 
 
 
 
 
 
 
fd0a8c2
7d1e58a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import hashlib
import logging
import os
import re
import textwrap

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables (HF_TOKEN, optionally HF_MODEL_NAME) from a .env
# file into os.environ before they are read below.
load_dotenv()

# Fail fast, before touching the filesystem, when the Hugging Face token is missing.
API_TOKEN = os.getenv("HF_TOKEN")
if not API_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set.")

# The model id may be overridden with HF_MODEL_NAME; the default is unchanged.
# NOTE(review): this id looks like a Google Gemini model rather than a Hugging
# Face Hub model — confirm InferenceClient can actually serve it.
MODEL_NAME = os.getenv("HF_MODEL_NAME", "gemini-2.5-pro-preview-03-25")
client = InferenceClient(model=MODEL_NAME, token=API_TOKEN)

# Uploaded spreadsheets are stored here.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Generated plots are written here and served under /static/images/.
# This must happen before the mount below: StaticFiles raises at construction
# time if its directory does not exist, so mounting first crashed on a fresh
# checkout where "static" had not been created yet.
IMAGES_DIR = os.path.join("static", "images")
os.makedirs(IMAGES_DIR, exist_ok=True)

app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site make credentialed requests; restrict allow_origins if cookies or
# authentication are ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")

@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded Excel (.xlsx) or CSV file in UPLOAD_DIR.

    Only the final path component of the client-supplied filename is used, so
    a name such as "../../app.py" cannot write outside UPLOAD_DIR.

    Returns:
        ``{"filename": <stored name>}`` — the value to pass back as the
        ``filename`` form field of ``/generate-visualization/``.

    Raises:
        HTTPException(400): the filename is missing or does not end in
            ``.xlsx`` / ``.csv``.
    """
    # Normalise Windows separators first so basename() also strips them on POSIX.
    # UploadFile.filename can be None, which previously crashed with AttributeError.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if not safe_name.endswith((".xlsx", ".csv")):
        raise HTTPException(status_code=400, detail="File must be an Excel (.xlsx) or CSV file")

    # Read before opening for write so a failed read doesn't leave a truncated file.
    contents = await file.read()
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as buffer:
        buffer.write(contents)

    logger.info("File uploaded: %s", safe_name)
    return {"filename": safe_name}

# Captures the body of a markdown code fence (```python ... ```), which chat-style
# models often wrap their answer in despite being told not to.
_CODE_FENCE_RE = re.compile(r"```[\w+-]*[ \t]*\r?\n(.*?)```", re.DOTALL)
# Generated lines containing any of these substrings are dropped before exec().
_BLOCKED_SUBSTRINGS = ("pd.read_csv", "pd.read_excel", "http")
# "raise" / "print" are matched as whole words so that labels such as
# "Appraised value" or a "fingerprint" column are not discarded by accident.
_BLOCKED_WORDS_RE = re.compile(r"\b(?:raise|print)\b")


def _strip_code_wrappers(text: str) -> str:
    """Return *text* without surrounding markdown code fences or triple quotes."""
    text = text.strip()
    fenced = _CODE_FENCE_RE.search(text)
    if fenced:
        text = fenced.group(1)
    else:
        # An unterminated fence (e.g. output cut off by max_new_tokens) does not
        # match the pattern above; drop any stray fence lines instead.
        text = "\n".join(
            line for line in text.splitlines() if not line.strip().startswith("```")
        )
    text = text.strip()
    for quote in ('"""', "'''"):
        if text.startswith(quote):
            text = text[len(quote):]
        if text.endswith(quote):
            text = text[: -len(quote)]
    return text.strip()


def _extract_executable_code(code: str) -> str:
    """Reduce cleaned model output to the source that will be passed to exec().

    Drops blank lines, full-line comments, lines opening with triple quotes, and
    lines mentioning a blocked substring or word. Relative indentation is kept
    (the old per-line strip broke every loop / ``with`` / ``if`` body), so
    ``def``/``class`` headers are kept too: dropping a header while preserving
    indentation would orphan its body, and the old prefix check also discarded
    ordinary variables such as ``classes`` or ``default_x``.
    ``plt.show()`` becomes ``pass`` — the figure is saved instead, and ``pass``
    keeps a block valid if that call was its only statement.
    """
    kept = []
    for line in code.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith(("#", '"""', "'''")):
            continue
        if any(s in line for s in _BLOCKED_SUBSTRINGS) or _BLOCKED_WORDS_RE.search(line):
            continue
        kept.append(line.rstrip())
    result = textwrap.dedent("\n".join(kept))
    return result.replace("plt.show()", "pass").strip()


@app.post("/generate-visualization/")
async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
    """Generate a chart for an uploaded file from a natural-language prompt.

    The file's column names and ``df.head()`` preview are sent to the model
    together with *prompt*. The returned code is cleaned, executed with ``df``,
    ``pd``, ``plt`` and ``sns`` in scope, and the resulting matplotlib figure is
    saved as a PNG under IMAGES_DIR.

    Args:
        prompt: What to plot, in natural language.
        filename: Name returned by ``/upload/``; only its final path component
            is used.

    Returns:
        ``{"plot_url": <URL under /static/images/>, "generated_code": <cleaned
        model output>}``.

    Raises:
        HTTPException(404): the file does not exist in UPLOAD_DIR.
        HTTPException(400): the file cannot be read or is empty.
        HTTPException(500): the model call fails, yields no usable code, or the
            code fails to run or to produce a figure.
    """
    # NOTE: every call below is synchronous, so this handler blocks the event
    # loop for the duration of the model call and plotting. That also keeps
    # pyplot's process-global figure state single-threaded (pyplot is not
    # thread-safe); add a lock around the plotting if this moves to a thread pool.

    # Only the final path component is honoured so a crafted form field such as
    # "../../etc/passwd" cannot escape UPLOAD_DIR.
    safe_name = os.path.basename(filename.replace("\\", "/"))
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    if not safe_name or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found on server.")

    try:
        if safe_name.endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("File is empty.")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading file: {str(e)}") from e

    # Headers may be non-strings (e.g. numeric Excel headers); str.join needs str.
    columns = ", ".join(map(str, df.columns))
    input_text = f"""
    Given the DataFrame 'df' with columns {columns} and preview:
    {df.head().to_string()}
    Write Python code to: {prompt}
    - Use ONLY 'df' (no external data loading).
    - Use pandas (pd), matplotlib.pyplot (plt), or seaborn (sns).
    - Include axis labels and a title.
    - Output ONLY executable code (no comments, functions, print, or triple quotes).
    """

    try:
        raw_output = client.text_generation(input_text, max_new_tokens=500)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}") from e
    logger.info("Generated code:\n%s", raw_output)

    if not raw_output or not raw_output.strip():
        raise HTTPException(status_code=500, detail="No code generated by the AI model.")

    generated_code = _strip_code_wrappers(raw_output)
    executable_code = _extract_executable_code(generated_code)
    logger.info("Executable code:\n%s", executable_code)

    if not executable_code:
        # Previously an empty program still "succeeded" and saved a blank image.
        raise HTTPException(
            status_code=500,
            detail="No executable code left after filtering the model output.",
        )

    plot_hash = hashlib.md5(f"{safe_name}_{prompt}".encode()).hexdigest()[:8]
    plot_filename = f"plot_{plot_hash}.png"
    plot_path = os.path.join(IMAGES_DIR, plot_filename)

    # SECURITY: this runs model-generated code in-process with full interpreter
    # privileges. The line filtering above is a heuristic, NOT a sandbox; a
    # prompt-injected or misbehaving model can still execute arbitrary code.
    # Run this in an isolated worker/process if the service is exposed publicly.
    plt.close("all")  # start from a clean pyplot state
    try:
        exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
        exec(executable_code, exec_globals)
        if not plt.get_fignums():
            raise ValueError("Generated code did not create a plot.")
        plt.savefig(plot_path, bbox_inches="tight")
    except Exception as e:
        logger.error(f"Error executing code:\n{executable_code}\nException: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}") from e
    finally:
        # Always release figures — previously a failing exec left them open and
        # they leaked into the next request's plot.
        plt.close("all")

    if not os.path.exists(plot_path):
        raise HTTPException(status_code=500, detail="Plot file was not created.")

    return {"plot_url": f"/static/images/{plot_filename}", "generated_code": generated_code}

@app.get("/")
async def serve_frontend():
    """Serve the single-page frontend stored at static/index.html.

    Raises:
        HTTPException(404): static/index.html does not exist.
    """
    index_path = os.path.join("static", "index.html")
    try:
        # Decode explicitly as UTF-8 so the result doesn't depend on the host locale.
        with open(index_path, "r", encoding="utf-8") as f:
            html = f.read()
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail="Frontend not found.") from e
    return HTMLResponse(content=html)