Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from io import BytesIO | |
import numpy as np | |
# Set the style for all plots - using a built-in style | |
plt.style.use('fivethirtyeight') | |
def configure_plot_style(fig, ax): | |
"""Configure common plot styling elements""" | |
ax.spines['top'].set_visible(False) | |
ax.spines['right'].set_visible(False) | |
ax.grid(True, linestyle='--', alpha=0.7) | |
fig.patch.set_facecolor('white') | |
ax.set_facecolor('white') | |
st.title("Interactive Dataset Plotting Tool") | |
# Upload Dataset | |
uploaded_file = st.file_uploader("Upload your CSV dataset", type=["csv"]) | |
if uploaded_file: | |
try: | |
# Load dataset | |
df = pd.read_csv(uploaded_file) | |
st.write("Dataset Preview:") | |
st.dataframe(df) | |
# Plot type selection | |
plot_types = ["Line Plot", "Bar Plot", "Scatter Plot", "Histogram", "Box Plot", "Correlation Matrix"] | |
plot_type = st.selectbox("Select Plot Type:", plot_types) | |
# Color scheme selection | |
color_schemes = ['viridis', 'magma', 'plasma', 'inferno', 'cividis'] | |
color_scheme = st.selectbox("Select Color Scheme:", color_schemes) | |
# Common figure creation | |
fig, ax = plt.subplots(figsize=(10, 6)) | |
configure_plot_style(fig, ax) | |
if plot_type in ["Line Plot", "Bar Plot"]: | |
x_column = st.selectbox("Select X-axis column:", df.columns) | |
y_column = st.selectbox("Select Y-axis column:", df.columns) | |
if not pd.api.types.is_numeric_dtype(df[y_column]): | |
st.warning("Y-axis column must be numeric for this plot type.") | |
else: | |
if plot_type == "Line Plot": | |
ax.plot(df[x_column], df[y_column], marker='o', linewidth=2, | |
color=plt.cm.get_cmap(color_scheme)(0.6)) | |
else: | |
ax.bar(df[x_column], df[y_column], color=plt.cm.get_cmap(color_scheme)(0.6)) | |
ax.set_title(f"{plot_type} of {y_column} vs {x_column}", pad=20, fontsize=14) | |
ax.set_xlabel(x_column, fontsize=12) | |
ax.set_ylabel(y_column, fontsize=12) | |
plt.xticks(rotation=45 if len(df[x_column].unique()) > 10 else 0) | |
elif plot_type == "Scatter Plot": | |
x_column = st.selectbox("Select X-axis column:", df.columns) | |
y_column = st.selectbox("Select Y-axis column:", df.columns) | |
if not pd.api.types.is_numeric_dtype(df[x_column]) or not pd.api.types.is_numeric_dtype(df[y_column]): | |
st.warning("Both X and Y columns must be numeric for scatter plot.") | |
else: | |
scatter = ax.scatter(df[x_column], df[y_column], | |
c=np.arange(len(df)), cmap=color_scheme, | |
alpha=0.6, s=100) | |
plt.colorbar(scatter, ax=ax, label='Index') | |
ax.set_title(f"Scatter Plot of {y_column} vs {x_column}", pad=20, fontsize=14) | |
ax.set_xlabel(x_column, fontsize=12) | |
ax.set_ylabel(y_column, fontsize=12) | |
elif plot_type == "Histogram": | |
column = st.selectbox("Select column:", df.columns) | |
bins = st.slider("Number of bins:", min_value=5, max_value=50, value=20) | |
if not pd.api.types.is_numeric_dtype(df[column]): | |
st.warning("Column must be numeric for histogram.") | |
else: | |
n, bins, patches = ax.hist(df[column], bins=bins, edgecolor='white', linewidth=1) | |
for i, patch in enumerate(patches): | |
patch.set_facecolor(plt.cm.get_cmap(color_scheme)(i/len(patches))) | |
ax.set_title(f"Histogram of {column}", pad=20, fontsize=14) | |
ax.set_xlabel(column, fontsize=12) | |
ax.set_ylabel("Frequency", fontsize=12) | |
elif plot_type == "Box Plot": | |
x_column = st.selectbox("Select grouping column:", df.columns) | |
y_column = st.selectbox("Select value column:", df.columns) | |
if not pd.api.types.is_numeric_dtype(df[y_column]): | |
st.warning("Value column must be numeric for box plot.") | |
else: | |
box_plot = ax.boxplot([group[1][y_column].values for group in df.groupby(x_column)], | |
labels=df[x_column].unique(), | |
patch_artist=True) | |
# Color the boxes | |
colors = [plt.cm.get_cmap(color_scheme)(i/len(box_plot['boxes'])) | |
for i in range(len(box_plot['boxes']))] | |
for patch, color in zip(box_plot['boxes'], colors): | |
patch.set_facecolor(color) | |
patch.set_alpha(0.7) | |
ax.set_title(f"Box Plot of {y_column} grouped by {x_column}", pad=20, fontsize=14) | |
ax.set_xlabel(x_column, fontsize=12) | |
ax.set_ylabel(y_column, fontsize=12) | |
plt.xticks(rotation=45 if len(df[x_column].unique()) > 10 else 0) | |
elif plot_type == "Correlation Matrix": | |
numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns | |
numeric_df = df[numeric_columns] | |
if len(numeric_columns) == 0: | |
st.warning("No numeric columns found in the dataset for correlation matrix.") | |
else: | |
corr = numeric_df.corr() | |
im = ax.imshow(corr, cmap=color_scheme) | |
plt.colorbar(im, ax=ax) | |
# Add correlation values | |
for i in range(len(corr)): | |
for j in range(len(corr)): | |
text = ax.text(j, i, f'{corr.iloc[i, j]:.2f}', | |
ha='center', va='center', | |
color='white' if abs(corr.iloc[i, j]) > 0.5 else 'black') | |
ax.set_xticks(range(len(corr.columns))) | |
ax.set_yticks(range(len(corr.columns))) | |
ax.set_xticklabels(corr.columns, rotation=45, ha='right') | |
ax.set_yticklabels(corr.columns) | |
ax.set_title("Correlation Matrix", pad=20, fontsize=14) | |
# Adjust layout and display plot | |
plt.tight_layout() | |
st.pyplot(fig) | |
# Download button | |
buffer = BytesIO() | |
plt.savefig(buffer, format="png", dpi=300, bbox_inches='tight') | |
buffer.seek(0) | |
st.download_button( | |
label="Download Plot as PNG", | |
data=buffer, | |
file_name="plot.png", | |
mime="image/png" | |
) | |
except Exception as e: | |
st.error(f"An error occurred: {str(e)}") | |
st.info("Please make sure your dataset is properly formatted and contains appropriate data types for the selected plot type.") |