Spaces:
Sleeping
Sleeping
File size: 6,982 Bytes
d1ae8d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import numpy as np
# Apply matplotlib's built-in 'fivethirtyeight' style sheet to all plots.
# This runs once at module load, so every figure created below picks it up.
plt.style.use('fivethirtyeight')
def configure_plot_style(fig, ax):
    """Apply the shared look used by every plot in this app.

    Hides the top and right spines, switches on a dashed, semi-transparent
    grid, and paints both the figure and the axes background white.

    Args:
        fig: The figure whose background patch is recoloured.
        ax: The axes whose spines, grid and background are adjusted.
    """
    # Drop the frame edges that carry no axis information.
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    ax.grid(True, linestyle='--', alpha=0.7)

    # Same background for the figure canvas and the plotting area.
    for surface in (fig.patch, ax):
        surface.set_facecolor('white')
# ---------------------------------------------------------------------------
# Streamlit page: upload a CSV, pick a plot type and colour scheme, render the
# plot with matplotlib and offer it as a PNG download.
# ---------------------------------------------------------------------------
st.title("Interactive Dataset Plotting Tool")

# Upload Dataset
uploaded_file = st.file_uploader("Upload your CSV dataset", type=["csv"])

if uploaded_file is not None:
    # Track the figure outside the try-block so it can always be closed.
    fig = None
    try:
        # Load dataset
        df = pd.read_csv(uploaded_file)
        st.write("Dataset Preview:")
        st.dataframe(df)

        # Plot type selection
        plot_types = ["Line Plot", "Bar Plot", "Scatter Plot", "Histogram", "Box Plot", "Correlation Matrix"]
        plot_type = st.selectbox("Select Plot Type:", plot_types)

        # Color scheme selection
        color_schemes = ['viridis', 'magma', 'plasma', 'inferno', 'cividis']
        color_scheme = st.selectbox("Select Color Scheme:", color_schemes)
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # matplotlib 3.9; resolve the colormap once and reuse it below.
        cmap = plt.get_cmap(color_scheme)

        # Common figure creation
        fig, ax = plt.subplots(figsize=(10, 6))
        configure_plot_style(fig, ax)

        # Set to True only when a branch actually draws something, so that a
        # validation warning does not also show an empty chart.
        plotted = False

        if plot_type in ("Line Plot", "Bar Plot"):
            x_column = st.selectbox("Select X-axis column:", df.columns)
            y_column = st.selectbox("Select Y-axis column:", df.columns)
            if not pd.api.types.is_numeric_dtype(df[y_column]):
                st.warning("Y-axis column must be numeric for this plot type.")
            else:
                if plot_type == "Line Plot":
                    ax.plot(df[x_column], df[y_column], marker='o', linewidth=2,
                            color=cmap(0.6))
                else:
                    ax.bar(df[x_column], df[y_column], color=cmap(0.6))
                ax.set_title(f"{plot_type} of {y_column} vs {x_column}", pad=20, fontsize=14)
                ax.set_xlabel(x_column, fontsize=12)
                ax.set_ylabel(y_column, fontsize=12)
                # Tilt the tick labels only when there are many distinct x values.
                many_ticks = df[x_column].nunique(dropna=False) > 10
                ax.tick_params(axis='x', labelrotation=45 if many_ticks else 0)
                plotted = True

        elif plot_type == "Scatter Plot":
            x_column = st.selectbox("Select X-axis column:", df.columns)
            y_column = st.selectbox("Select Y-axis column:", df.columns)
            x_is_numeric = pd.api.types.is_numeric_dtype(df[x_column])
            y_is_numeric = pd.api.types.is_numeric_dtype(df[y_column])
            if not (x_is_numeric and y_is_numeric):
                st.warning("Both X and Y columns must be numeric for scatter plot.")
            else:
                # Points are coloured by their row position in the dataset.
                scatter = ax.scatter(df[x_column], df[y_column],
                                     c=np.arange(len(df)), cmap=color_scheme,
                                     alpha=0.6, s=100)
                fig.colorbar(scatter, ax=ax, label='Index')
                ax.set_title(f"Scatter Plot of {y_column} vs {x_column}", pad=20, fontsize=14)
                ax.set_xlabel(x_column, fontsize=12)
                ax.set_ylabel(y_column, fontsize=12)
                plotted = True

        elif plot_type == "Histogram":
            column = st.selectbox("Select column:", df.columns)
            bin_count = st.slider("Number of bins:", min_value=5, max_value=50, value=20)
            if not pd.api.types.is_numeric_dtype(df[column]):
                st.warning("Column must be numeric for histogram.")
            else:
                values = df[column].dropna()
                if values.empty:
                    # An all-missing column has no finite range to bin.
                    st.warning("Selected column has no non-missing values to plot.")
                else:
                    _, _, patches = ax.hist(values, bins=bin_count,
                                            edgecolor='white', linewidth=1)
                    # Sweep the colormap across the bars from left to right.
                    for i, patch in enumerate(patches):
                        patch.set_facecolor(cmap(i / len(patches)))
                    ax.set_title(f"Histogram of {column}", pad=20, fontsize=14)
                    ax.set_xlabel(column, fontsize=12)
                    ax.set_ylabel("Frequency", fontsize=12)
                    plotted = True

        elif plot_type == "Box Plot":
            x_column = st.selectbox("Select grouping column:", df.columns)
            y_column = st.selectbox("Select value column:", df.columns)
            if not pd.api.types.is_numeric_dtype(df[y_column]):
                st.warning("Value column must be numeric for box plot.")
            else:
                # Take labels and data from the same groupby pass so each box
                # is labelled with its own group key. Missing values are
                # dropped because they would turn the box statistics into NaN.
                group_labels = []
                group_values = []
                for key, group in df.groupby(x_column):
                    group_labels.append(str(key))
                    group_values.append(group[y_column].dropna().to_numpy())

                if not group_values:
                    st.warning("Grouping column has no non-missing values to group by.")
                else:
                    box_plot = ax.boxplot(group_values, patch_artist=True)
                    # Boxes sit at positions 1..N by default; label them
                    # explicitly instead of via the boxplot `labels` keyword,
                    # which newer matplotlib releases deprecate.
                    ax.set_xticks(range(1, len(group_labels) + 1))
                    ax.set_xticklabels(group_labels)

                    # Color the boxes
                    boxes = box_plot['boxes']
                    for i, patch in enumerate(boxes):
                        patch.set_facecolor(cmap(i / len(boxes)))
                        patch.set_alpha(0.7)

                    ax.set_title(f"Box Plot of {y_column} grouped by {x_column}", pad=20, fontsize=14)
                    ax.set_xlabel(x_column, fontsize=12)
                    ax.set_ylabel(y_column, fontsize=12)
                    ax.tick_params(axis='x',
                                   labelrotation=45 if len(group_labels) > 10 else 0)
                    plotted = True

        elif plot_type == "Correlation Matrix":
            # 'number' selects every numeric dtype, not just float64/int64.
            numeric_df = df.select_dtypes(include='number')
            if numeric_df.shape[1] == 0:
                st.warning("No numeric columns found in the dataset for correlation matrix.")
            else:
                corr = numeric_df.corr()
                corr_values = corr.to_numpy()
                im = ax.imshow(corr_values, cmap=color_scheme)
                fig.colorbar(im, ax=ax)

                # Add correlation values
                for (i, j), value in np.ndenumerate(corr_values):
                    ax.text(j, i, f'{value:.2f}',
                            ha='center', va='center',
                            color='white' if abs(value) > 0.5 else 'black')

                tick_positions = range(len(corr.columns))
                ax.set_xticks(tick_positions)
                ax.set_yticks(tick_positions)
                ax.set_xticklabels(corr.columns, rotation=45, ha='right')
                ax.set_yticklabels(corr.columns)
                ax.set_title("Correlation Matrix", pad=20, fontsize=14)
                plotted = True

        if plotted:
            # Adjust layout and display plot. The figure is addressed directly
            # rather than through pyplot's global "current figure".
            fig.tight_layout()
            st.pyplot(fig)

            # Download button
            buffer = BytesIO()
            fig.savefig(buffer, format="png", dpi=300, bbox_inches='tight')
            buffer.seek(0)
            st.download_button(
                label="Download Plot as PNG",
                data=buffer,
                file_name="plot.png",
                mime="image/png"
            )
    except Exception as e:
        # Top-level UI boundary: report the problem to the user instead of
        # letting the app crash with a traceback.
        st.error(f"An error occurred: {str(e)}")
        st.info("Please make sure your dataset is properly formatted and contains appropriate data types for the selected plot type.")
    finally:
        # pyplot keeps every figure alive until it is closed; close this one
        # so figures do not accumulate across reruns of the script.
        if fig is not None:
            plt.close(fig)