bluenevus commited on
Commit
4fc79a4
·
verified ·
1 Parent(s): 1f5fb09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -128
app.py CHANGED
@@ -1,133 +1,109 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import HTMLResponse
4
- from fastapi.staticfiles import StaticFiles
5
  import pandas as pd
6
  import matplotlib.pyplot as plt
7
- import seaborn as sns
8
- import os
9
- import logging
10
- from huggingface_hub import InferenceClient
11
- from dotenv import load_dotenv
12
- import hashlib
13
-
14
- # Set up logging
15
- logging.basicConfig(level=logging.INFO)
16
- logger = logging.getLogger(__name__)
17
-
18
- # Load environment variables
19
- load_dotenv()
20
-
21
- app = FastAPI()
22
-
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"],
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
- app.mount("/static", StaticFiles(directory="static"), name="static")
32
-
33
- API_TOKEN = os.getenv("HF_TOKEN")
34
- if not API_TOKEN:
35
- raise ValueError("HF_TOKEN environment variable not set.")
36
-
37
- MODEL_NAME = "gemini-2.5-pro-preview-03-25"
38
- client = InferenceClient(model=MODEL_NAME, token=API_TOKEN)
39
-
40
- UPLOAD_DIR = "uploads"
41
- os.makedirs(UPLOAD_DIR, exist_ok=True)
42
-
43
- IMAGES_DIR = os.path.join("static", "images")
44
- os.makedirs(IMAGES_DIR, exist_ok=True)
45
-
46
- @app.post("/upload/")
47
- async def upload_file(file: UploadFile = File(...)):
48
- if not file.filename.endswith((".xlsx", ".csv")):
49
- raise HTTPException(status_code=400, detail="File must be an Excel (.xlsx) or CSV file")
50
-
51
- file_path = os.path.join(UPLOAD_DIR, file.filename)
52
- with open(file_path, "wb") as buffer:
53
- buffer.write(await file.read())
54
-
55
- logger.info(f"File uploaded: {file.filename}")
56
- return {"filename": file.filename}
57
-
58
- @app.post("/generate-visualization/")
59
- async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
60
- file_path = os.path.join(UPLOAD_DIR, filename)
61
-
62
- if not os.path.exists(file_path):
63
- raise HTTPException(status_code=404, detail="File not found on server.")
64
-
65
- try:
66
- if filename.endswith('.csv'):
67
- df = pd.read_csv(file_path)
68
- else:
69
- df = pd.read_excel(file_path)
70
- if df.empty:
71
- raise ValueError("File is empty.")
72
- except Exception as e:
73
- raise HTTPException(status_code=400, detail=f"Error reading file: {str(e)}")
74
-
75
- input_text = f"""
76
- Given the DataFrame 'df' with columns {', '.join(df.columns)} and preview:
77
- {df.head().to_string()}
78
- Write Python code to: {prompt}
79
- - Use ONLY 'df' (no external data loading).
80
- - Use pandas (pd), matplotlib.pyplot (plt), or seaborn (sns).
81
- - Include axis labels and a title.
82
- - Output ONLY executable code (no comments, functions, print, or triple quotes).
83
  """
84
 
85
- try:
86
- generated_code = client.text_generation(input_text, max_new_tokens=500)
87
- logger.info(f"Generated code:\n{generated_code}")
88
- except Exception as e:
89
- raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}")
90
-
91
- if not generated_code.strip():
92
- raise HTTPException(status_code=500, detail="No code generated by the AI model.")
93
-
94
- generated_code = generated_code.strip()
95
- if generated_code.startswith('"""') or generated_code.startswith("'''"):
96
- generated_code = generated_code.split('"""')[1] if '"""' in generated_code else generated_code.split("'''")[1]
97
- if generated_code.endswith('"""') or generated_code.endswith("'''"):
98
- generated_code = generated_code.rsplit('"""')[0] if '"""' in generated_code else generated_code.rsplit("'''")[0]
99
- generated_code = generated_code.strip()
100
-
101
- lines = generated_code.splitlines()
102
- executable_code = "\n".join(
103
- line.strip() for line in lines
104
- if line.strip() and not line.strip().startswith(('#', 'def', 'class', '"""', "'''"))
105
- and not any(kw in line for kw in ["pd.read_csv", "pd.read_excel", "http", "raise", "print"])
106
- ).strip()
107
-
108
- executable_code = executable_code.replace("plt.show()", "").strip()
109
-
110
- logger.info(f"Executable code:\n{executable_code}")
111
-
112
- plot_hash = hashlib.md5(f"{filename}_{prompt}".encode()).hexdigest()[:8]
113
- plot_filename = f"plot_{plot_hash}.png"
114
- plot_path = os.path.join(IMAGES_DIR, plot_filename)
115
-
116
- try:
117
- exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
118
- exec(executable_code, exec_globals)
119
- plt.savefig(plot_path, bbox_inches="tight")
120
  plt.close()
121
- except Exception as e:
122
- logger.error(f"Error executing code:\n{executable_code}\nException: {str(e)}")
123
- raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}")
124
-
125
- if not os.path.exists(plot_path):
126
- raise HTTPException(status_code=500, detail="Plot file was not created.")
127
-
128
- return {"plot_url": f"/static/images/{plot_filename}", "generated_code": generated_code}
129
-
130
- @app.get("/")
131
- async def serve_frontend():
132
- with open("static/index.html", "r") as f:
133
- return HTMLResponse(content=f.read())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
 
 
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
+ import io
5
+ import base64
6
+ import google.generativeai as genai
7
+
8
+ def process_file(api_key, file, instructions):
9
+ # Set up Gemini API
10
+ genai.configure(api_key=api_key)
11
+ model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
12
+
13
+ # Read the file
14
+ if file.name.endswith('.csv'):
15
+ df = pd.read_csv(file.name)
16
+ else:
17
+ df = pd.read_excel(file.name)
18
+
19
+ # Analyze data and get visualization suggestions from Gemini
20
+ data_description = df.describe().to_string()
21
+ columns_info = "\n".join([f"{col}: {df[col].dtype}" for col in df.columns])
22
+ prompt = f"""
23
+ Given this dataset:
24
+ Columns and types:
25
+ {columns_info}
26
+
27
+ Data summary:
28
+ {data_description}
29
+
30
+ User instructions: {instructions if instructions else 'No specific instructions provided.'}
31
+
32
+ Suggest 3 ways to visualize this data. For each visualization:
33
+ 1. Describe the visualization type and what it will show.
34
+ 2. Provide Python code using matplotlib to create the visualization.
35
+ 3. Explain why this visualization is useful for understanding the data.
36
+
37
+ Format your response as:
38
+ Visualization 1:
39
+ Description: ...
40
+ Code: ...
41
+ Explanation: ...
42
+
43
+ Visualization 2:
44
+ ...
45
+
46
+ Visualization 3:
47
+ ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  """
49
 
50
+ response = model.generate_content(prompt)
51
+ suggestions = response.text.split("Visualization")
52
+
53
+ visualizations = []
54
+ for i, suggestion in enumerate(suggestions[1:4], 1): # Process only the first 3 visualizations
55
+ parts = suggestion.split("Code:")
56
+ description = parts[0].strip()
57
+ code = parts[1].split("Explanation:")[0].strip()
58
+
59
+ # Execute the code
60
+ plt.figure(figsize=(10, 6))
61
+ exec(code)
62
+ plt.title(f"Visualization {i}")
63
+
64
+ # Save the plot to a BytesIO object
65
+ buf = io.BytesIO()
66
+ plt.savefig(buf, format='png')
67
+ buf.seek(0)
68
+ img_str = base64.b64encode(buf.getvalue()).decode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  plt.close()
70
+
71
+ visualizations.append((f"data:image/png;base64,{img_str}", description, code))
72
+
73
+ return visualizations
74
+
75
+ # Gradio interface
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("# Data Visualization with Gemini")
78
+ api_key = gr.Textbox(label="Enter Gemini API Key", type="password")
79
+ file = gr.File(label="Upload Excel or CSV file")
80
+ instructions = gr.Textbox(label="Optional visualization instructions")
81
+ submit = gr.Button("Generate Visualizations")
82
+
83
+ with gr.Row():
84
+ output1 = gr.Image(label="Visualization 1")
85
+ output2 = gr.Image(label="Visualization 2")
86
+ output3 = gr.Image(label="Visualization 3")
87
+
88
+ with gr.Row():
89
+ desc1 = gr.Textbox(label="Description 1")
90
+ desc2 = gr.Textbox(label="Description 2")
91
+ desc3 = gr.Textbox(label="Description 3")
92
+
93
+ with gr.Row():
94
+ code1 = gr.Code(language="python", label="Code 1")
95
+ code2 = gr.Code(language="python", label="Code 2")
96
+ code3 = gr.Code(language="python", label="Code 3")
97
+
98
+ submit.click(
99
+ fn=process_file,
100
+ inputs=[api_key, file, instructions],
101
+ outputs=[
102
+ output1, desc1, code1,
103
+ output2, desc2, code2,
104
+ output3, desc3, code3
105
+ ],
106
+ show_progress=True,
107
+ )
108
+
109
+ demo.launch()