File size: 6,982 Bytes
d1ae8d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import numpy as np

# Apply Matplotlib's built-in 'fivethirtyeight' style sheet globally, so every
# figure created later in this script starts from the same defaults.
plt.style.use('fivethirtyeight')

def configure_plot_style(fig, ax):
    """Apply the shared look used by every chart in this app.

    Hides the top and right spines of ``ax``, switches on a dashed,
    semi-transparent grid, and paints both the figure and the axes
    backgrounds white. Both objects are modified in place; nothing is
    returned.
    """
    # Drop the upper and right-hand frame lines.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    grid_options = {'linestyle': '--', 'alpha': 0.7}
    ax.grid(True, **grid_options)

    # Use one background colour for the canvas and the plotting area.
    background = 'white'
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)


# ---------------------------------------------------------------------------
# Streamlit page body.
#
# Flow: upload a CSV -> preview it -> pick a plot type and colour map ->
# pick the columns for that plot type -> draw onto one Matplotlib figure ->
# show the figure and offer it as a PNG download.
#
# Streamlit re-runs this script on every widget interaction, so the figure
# is always closed at the end of a run to avoid accumulating open figures.
# ---------------------------------------------------------------------------
st.title("Interactive Dataset Plotting Tool")

# Upload Dataset
uploaded_file = st.file_uploader("Upload your CSV dataset", type=["csv"])

if uploaded_file:
    fig = None
    try:
        # Load dataset
        df = pd.read_csv(uploaded_file)
        st.write("Dataset Preview:")
        st.dataframe(df)

        # Plot type selection
        plot_types = ["Line Plot", "Bar Plot", "Scatter Plot", "Histogram", "Box Plot", "Correlation Matrix"]
        plot_type = st.selectbox("Select Plot Type:", plot_types)

        # Color scheme selection
        color_schemes = ['viridis', 'magma', 'plasma', 'inferno', 'cividis']
        color_scheme = st.selectbox("Select Color Scheme:", color_schemes)

        # Common figure creation
        fig, ax = plt.subplots(figsize=(10, 6))
        configure_plot_style(fig, ax)

        # Resolve the colormap once. plt.get_cmap is used because
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
        cmap = plt.get_cmap(color_scheme)

        # Set to True only when a branch actually draws something, so that a
        # validation warning is not followed by an empty chart and download.
        plotted = False

        if plot_type in ("Line Plot", "Bar Plot"):
            x_column = st.selectbox("Select X-axis column:", df.columns)
            y_column = st.selectbox("Select Y-axis column:", df.columns)

            if not pd.api.types.is_numeric_dtype(df[y_column]):
                st.warning("Y-axis column must be numeric for this plot type.")
            else:
                if plot_type == "Line Plot":
                    ax.plot(df[x_column], df[y_column], marker='o', linewidth=2,
                            color=cmap(0.6))
                else:
                    ax.bar(df[x_column], df[y_column], color=cmap(0.6))

                ax.set_title(f"{plot_type} of {y_column} vs {x_column}", pad=20, fontsize=14)
                ax.set_xlabel(x_column, fontsize=12)
                ax.set_ylabel(y_column, fontsize=12)
                # Tilt the x labels only when there are many distinct values.
                many_x_values = df[x_column].nunique(dropna=False) > 10
                ax.tick_params(axis='x', labelrotation=45 if many_x_values else 0)
                plotted = True

        elif plot_type == "Scatter Plot":
            x_column = st.selectbox("Select X-axis column:", df.columns)
            y_column = st.selectbox("Select Y-axis column:", df.columns)

            x_is_numeric = pd.api.types.is_numeric_dtype(df[x_column])
            y_is_numeric = pd.api.types.is_numeric_dtype(df[y_column])
            if not (x_is_numeric and y_is_numeric):
                st.warning("Both X and Y columns must be numeric for scatter plot.")
            else:
                # Points are coloured by their row position in the dataset.
                scatter = ax.scatter(df[x_column], df[y_column],
                                     c=np.arange(len(df)), cmap=cmap,
                                     alpha=0.6, s=100)
                fig.colorbar(scatter, ax=ax, label='Index')
                ax.set_title(f"Scatter Plot of {y_column} vs {x_column}", pad=20, fontsize=14)
                ax.set_xlabel(x_column, fontsize=12)
                ax.set_ylabel(y_column, fontsize=12)
                plotted = True

        elif plot_type == "Histogram":
            column = st.selectbox("Select column:", df.columns)
            bins = st.slider("Number of bins:", min_value=5, max_value=50, value=20)

            if not pd.api.types.is_numeric_dtype(df[column]):
                st.warning("Column must be numeric for histogram.")
            else:
                # Missing values are dropped first: NaN makes the automatic
                # bin-range detection fail.
                values = df[column].dropna()
                _, _, patches = ax.hist(values, bins=bins, edgecolor='white', linewidth=1)
                # Colour the bars progressively along the colormap.
                for i, patch in enumerate(patches):
                    patch.set_facecolor(cmap(i / len(patches)))

                ax.set_title(f"Histogram of {column}", pad=20, fontsize=14)
                ax.set_xlabel(column, fontsize=12)
                ax.set_ylabel("Frequency", fontsize=12)
                plotted = True

        elif plot_type == "Box Plot":
            x_column = st.selectbox("Select grouping column:", df.columns)
            y_column = st.selectbox("Select value column:", df.columns)

            if not pd.api.types.is_numeric_dtype(df[y_column]):
                st.warning("Value column must be numeric for box plot.")
            else:
                # Take labels and data from the same groupby pass so each box
                # is labelled with its own group key. groupby sorts keys and
                # skips NaN keys, which df[x_column].unique() does not.
                labels = []
                samples = []
                for key, group in df.groupby(x_column):
                    labels.append(str(key))
                    samples.append(group[y_column].dropna().to_numpy())

                if not samples:
                    st.warning("Grouping column has no non-missing values for box plot.")
                else:
                    box_plot = ax.boxplot(samples, patch_artist=True)
                    # Set tick labels explicitly rather than through the
                    # boxplot `labels` keyword, which newer Matplotlib
                    # releases deprecate.
                    ax.set_xticklabels(labels)

                    # Color the boxes
                    boxes = box_plot['boxes']
                    for i, patch in enumerate(boxes):
                        patch.set_facecolor(cmap(i / len(boxes)))
                        patch.set_alpha(0.7)

                    ax.set_title(f"Box Plot of {y_column} grouped by {x_column}", pad=20, fontsize=14)
                    ax.set_xlabel(x_column, fontsize=12)
                    ax.set_ylabel(y_column, fontsize=12)
                    ax.tick_params(axis='x', labelrotation=45 if len(labels) > 10 else 0)
                    plotted = True

        elif plot_type == "Correlation Matrix":
            # 'number' selects every numeric dtype, not only float64/int64.
            numeric_df = df.select_dtypes(include='number')

            if numeric_df.shape[1] == 0:
                st.warning("No numeric columns found in the dataset for correlation matrix.")
            else:
                corr = numeric_df.corr()
                im = ax.imshow(corr.to_numpy(), cmap=cmap)
                fig.colorbar(im, ax=ax)

                # Add correlation values
                size = len(corr)
                for i in range(size):
                    for j in range(size):
                        value = corr.iloc[i, j]
                        ax.text(j, i, f'{value:.2f}',
                                ha='center', va='center',
                                color='white' if abs(value) > 0.5 else 'black')

                ax.set_xticks(range(size))
                ax.set_yticks(range(size))
                ax.set_xticklabels(corr.columns, rotation=45, ha='right')
                ax.set_yticklabels(corr.columns)
                ax.set_title("Correlation Matrix", pad=20, fontsize=14)
                plotted = True

        if plotted:
            # Adjust layout and display plot
            fig.tight_layout()
            st.pyplot(fig)

            # Download button. Saving through the figure object avoids relying
            # on pyplot's "current figure" state.
            buffer = BytesIO()
            fig.savefig(buffer, format="png", dpi=300, bbox_inches='tight')
            buffer.seek(0)
            st.download_button(
                label="Download Plot as PNG",
                data=buffer,
                file_name="plot.png",
                mime="image/png"
            )

    except Exception as e:
        # Top-level boundary of the page: report the failure in the UI.
        st.error(f"An error occurred: {str(e)}")
        st.info("Please make sure your dataset is properly formatted and contains appropriate data types for the selected plot type.")
    finally:
        # Release the figure so repeated reruns do not accumulate open figures.
        if fig is not None:
            plt.close(fig)