File size: 3,054 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
29f87f1
4fc79a4
 
5fcb5c4
69a028d
12598b6
4fc79a4
5fcb5c4
 
 
 
 
69a028d
5fcb5c4
 
69a028d
5fcb5c4
29f87f1
69a028d
5fcb5c4
 
 
 
 
 
 
 
 
 
4fc79a4
5fcb5c4
a888f55
5fcb5c4
a888f55
5fcb5c4
a888f55
 
5fcb5c4
 
 
 
 
a888f55
 
 
 
 
 
 
5fcb5c4
 
 
 
 
 
 
 
a888f55
 
 
5fcb5c4
a888f55
5fcb5c4
 
a888f55
5fcb5c4
a888f55
5fcb5c4
a888f55
 
 
5fcb5c4
 
a888f55
5fcb5c4
 
a888f55
 
5fcb5c4
 
a888f55
 
5fcb5c4
 
a888f55
5fcb5c4
 
a888f55
 
5fcb5c4
a888f55
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import google.generativeai as genai
from PIL import Image

def _load_dataframe(file):
    """Read the uploaded CSV/Excel file into a DataFrame.

    Depending on the Gradio version, ``gr.File`` hands over either a
    tempfile-like wrapper (exposing ``.name``) or the path string itself;
    both are accepted here.
    """
    path = str(getattr(file, "name", file))
    # Case-insensitive so "DATA.CSV" is not sent to the Excel reader.
    if path.lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Build the Gemini prompt describing *df* and the required code format."""
    return f"""Generate 3 matplotlib visualization codes for this data:
    Columns: {list(df.columns)}
    First 3 rows: {df.head(3).to_dict()}

    Requirements:
    1. The data is already loaded in a pandas DataFrame named `df`, and
       `plt` (matplotlib.pyplot) and `pd` (pandas) are already imported.
       Do not read any files, do not call plt.style.use(), and do not call plt.show().
    2. Each visualization must start with:
       plt.figure(figsize=(16,9), dpi=120)
    3. Include complete plotting code with:
       - Title
       - Axis labels
       - Legend if needed
       - plt.tight_layout()
    4. Different chart types (bar, line, scatter, etc)
    5. No explanations - only valid Python code

    User instructions: {instructions}

    Format exactly as:
    # Visualization 1
    [complete code]

    # Visualization 2
    [complete code]

    # Visualization 3
    [complete code]
    """


def _extract_code_blocks(text, count=3):
    """Split the model response into at most *count* runnable code snippets.

    Each snippet is the text after a ``# Visualization `` marker, minus the
    marker's own line, Markdown fences and ``plt.style.use(...)`` calls.
    Relative indentation is preserved so loops/conditionals stay valid.
    """
    import textwrap

    snippets = []
    for chunk in text.split("# Visualization ")[1:count + 1]:
        # The first line is the rest of the header ("1", "2: Bar chart", ...),
        # not code; executing it would be a no-op at best and a SyntaxError at worst.
        _, _, body = chunk.partition("\n")
        kept = [
            line
            for line in body.splitlines()
            # Style is applied by the host (see _apply_plot_style); the bare
            # 'seaborn' style no longer exists in recent matplotlib releases.
            if not line.lstrip().startswith(("```", "plt.style.use("))
        ]
        # dedent removes only the common margin, unlike per-line strip().
        snippets.append(textwrap.dedent("\n".join(kept)).strip("\n"))
    return snippets


def _apply_plot_style():
    """Apply the seaborn look, tolerating matplotlib's rename of that style.

    matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and
    later releases removed the old name, so try the new name first.
    """
    for name in ("seaborn-v0_8", "seaborn"):
        if name in plt.style.available:
            plt.style.use(name)
            return


def _render_visualization(code, df):
    """Execute one generated snippet against *df* and return it as a PIL image.

    Any figure created while rendering is closed afterwards, even on error,
    so repeated requests do not accumulate open figures.
    """
    buf = io.BytesIO()
    figures_before = set(plt.get_fignums())
    try:
        _apply_plot_style()
        plt.figure(figsize=(16, 9), dpi=120)
        # SECURITY: exec runs model-generated code with full interpreter
        # privileges (filesystem, network, imports). Only acceptable for a
        # trusted deployment; sandbox it before exposing this app publicly.
        exec(code, {"df": df, "plt": plt, "pd": pd})
        plt.tight_layout()
        plt.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Close only the figures this call opened (the generated code usually
        # creates its own via plt.figure, orphaning ours).
        for num in set(plt.get_fignums()) - figures_before:
            plt.close(num)
    buf.seek(0)
    return Image.open(buf)


def process_file(api_key, file, instructions):
    """Generate up to three matplotlib charts for an uploaded data file.

    Args:
        api_key: Gemini API key used to configure ``genai``.
        file: The Gradio upload (path string or object with ``.name``);
            falsy when nothing was uploaded.
        instructions: Free-text user guidance appended to the prompt.

    Returns:
        A list of exactly 3 items, each a PIL image or ``None`` when that
        visualization could not be produced.
    """
    if not file:
        # Nothing uploaded: nothing to plot, and no reason to call the API.
        return [None] * 3

    df = _load_dataframe(file)

    # Configure Gemini with precise model version
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
    response = model.generate_content(_build_prompt(df, instructions))

    visualizations = []
    for i, code in enumerate(_extract_code_blocks(response.text), 1):
        try:
            visualizations.append(_render_visualization(code, df))
        except Exception as e:
            # Best-effort per chart: one bad snippet must not sink the others.
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(None)

    # Return exactly 3 images, filling with None if needed
    return visualizations + [None] * (3 - len(visualizations))

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization Tool")

    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])

    instructions = gr.Textbox(label="Instructions (optional)")
    submit = gr.Button("Generate Visualizations")

    # One image slot per visualization; process_file pads its result to 3 values.
    with gr.Row():
        outputs = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]

    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

# Start the server only when run as a script, so importing this module
# (e.g. from tests or tooling that looks up `demo`) does not block in launch().
if __name__ == "__main__":
    demo.launch()