File size: 3,079 Bytes
4fc79a4
ff768e2
 
4fc79a4
 
29f87f1
4fc79a4
 
69a028d
 
4fc79a4
69a028d
 
 
 
 
 
 
 
a888f55
69a028d
 
 
 
29f87f1
69a028d
 
 
 
4fc79a4
a888f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image

def _read_dataset(file):
    """Load the uploaded CSV or Excel file into a DataFrame.

    ``file`` is the value of the ``gr.File`` component. Depending on the
    Gradio version this is either an object exposing the path as ``.name``
    or the path string itself, so both are accepted.

    Raises:
        ValueError: if no file was uploaded.
        Any error raised by ``pd.read_csv`` / ``pd.read_excel`` for
        unreadable input is propagated unchanged.
    """
    if file is None:
        raise ValueError("no file uploaded")
    path = getattr(file, "name", file)
    # Case-insensitive so e.g. "DATA.CSV" is not routed to the Excel reader.
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Return the Gemini prompt describing ``df`` and the required reply format."""
    return f"""Generate exactly 3 distinct matplotlib visualizations for:
    Columns: {list(df.columns)}
    Data types: {dict(df.dtypes)}
    Sample data: {df.head(3).to_dict()}
    
    Requirements:
    1. 1920x1080 resolution (figsize=(16,9), dpi=120)
    2. Professional styling (seaborn, grid, proper labels)
    3. Diverse chart types (include at least 1 advanced visualization)
    
    User instructions: {instructions or 'None provided'}
    
    Format response strictly as:
    # Visualization 1
    plt.figure(figsize=(16,9), dpi=120)
    [code]
    plt.tight_layout()
    
    # Visualization 2
    ...
    """


def _clean_code(block):
    """Turn one "# Visualization N" chunk of the model reply into runnable code.

    The first line (the rest of the heading, e.g. "1") is dropped and
    markdown code-fence lines are removed. Only the *common* leading
    indentation is then stripped with ``textwrap.dedent``, so nested blocks
    (loops, ``if`` bodies) keep their relative indentation. Stripping every
    line individually would turn them into a SyntaxError.
    """
    lines = block.split('\n')[1:]
    lines = [line for line in lines if not line.strip().startswith("```")]
    return textwrap.dedent('\n'.join(lines))


def _pick_style():
    """Return the name of the bundled seaborn matplotlib style, if present.

    Matplotlib 3.6 renamed the bundled ``'seaborn'`` style to
    ``'seaborn-v0_8'`` and later releases dropped the old name, so
    ``plt.style.use('seaborn')`` fails on current installs. Falls back to
    matplotlib's ``'default'`` style when neither name is available.
    """
    for name in ("seaborn-v0_8", "seaborn"):
        if name in plt.style.available:
            return name
    return "default"


def _placeholder():
    """Return the solid 1920x1080 image shown when a chart cannot be produced."""
    return Image.new('RGB', (1920, 1080), color=(73, 109, 137))


def _render_visualization(code, df, index, style):
    """Execute generated plotting ``code`` and return the current figure as a PIL image.

    Every pyplot figure opened while rendering is closed afterwards, also
    when ``code`` raises. The generated code usually calls ``plt.figure``
    itself, so closing only the current figure would leak the pre-created one.
    """
    existing = set(plt.get_fignums())
    try:
        # Scoped style: rcParams are restored when the block exits.
        with plt.style.context(style):
            plt.figure(figsize=(16, 9), dpi=120)
            # SECURITY: this executes model-generated code with full
            # interpreter privileges (imports, file system, network). Only
            # deploy where that is acceptable, or move this into a sandbox.
            # A copy of df is passed so one snippet cannot alter the data
            # seen by the next.
            exec(code, {'df': df.copy(), 'pd': pd, 'plt': plt})
            plt.title(f"Visualization {index}", fontsize=14)
            plt.tight_layout()
            buf = io.BytesIO()
            plt.savefig(buf, format='png', dpi=120, bbox_inches='tight')
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode now so a bad PNG fails here, not later in Gradio
        return image
    finally:
        for num in set(plt.get_fignums()) - existing:
            plt.close(num)


def process_file(api_key, file, instructions):
    """Ask Gemini for three matplotlib charts of the uploaded dataset and render them.

    Args:
        api_key: Gemini API key entered by the user.
        file: Value of the ``gr.File`` upload component (CSV or Excel).
        instructions: Optional free-text guidance appended to the prompt.

    Returns:
        A list of exactly three PIL images. A chart whose generated code
        fails, or that is missing from the reply, is replaced by a solid
        placeholder image.

    Raises:
        gr.Error: if the file cannot be read or the Gemini request fails.
            Gradio shows the message to the user. The image outputs cannot
            display an error string.
    """
    try:
        df = _read_dataset(file)
    except Exception as e:
        raise gr.Error(f"File Error: {str(e)}") from e

    try:
        genai.configure(api_key=api_key)
        # NOTE(review): confirm this model id is served for the user's key.
        model = genai.GenerativeModel('gemini-2.5-pro-latest')
        response = model.generate_content(_build_prompt(df, instructions))
        # Reading .text may raise (e.g. when the reply has no text part).
        reply = response.text
    except Exception as e:
        raise gr.Error(f"Gemini Error: {str(e)}") from e

    code_blocks = reply.split("# Visualization ")[1:4]
    style = _pick_style()

    visualizations = []
    for i, block in enumerate(code_blocks, 1):
        try:
            visualizations.append(
                _render_visualization(_clean_code(block), df, i, style)
            )
        except Exception as e:
            # Best effort per chart: one bad snippet must not sink the others.
            print(f"Visualization {i} failed: {str(e)}")
            visualizations.append(_placeholder())

    # The reply may contain fewer than three sections; pad with placeholders.
    visualizations.extend(_placeholder() for _ in range(3 - len(visualizations)))
    return visualizations[:3]

# Gradio interface: API key + dataset upload + optional instructions in,
# three rendered charts out. `process_file` does the work on submit.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# **HD Data Visualizer**  📊✨")

    with gr.Row():
        api_key = gr.Textbox(label="🔑 Gemini API Key", type="password")
        file = gr.File(label="📁 Upload Dataset", file_types=[".csv", ".xlsx"])

    instructions = gr.Textbox(label="💡 Custom Instructions (optional)",
                              placeholder="E.g.: Focus on time series patterns...")
    submit = gr.Button("🚀 Generate Visualizations", variant="primary")

    with gr.Row():
        # One image slot per chart; process_file always returns exactly three.
        outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]

    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs
    )

# Launch only when run as a script, so importing this module (e.g. from tests
# or another launcher that uses `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()