# DataVisualize / app.py
# Author: mulasagg
# Commit: d1ae8d3 ("final")
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
import numpy as np
# Apply Matplotlib's bundled "fivethirtyeight" style sheet globally, before any
# figure in this app is created.
plt.style.use(style='fivethirtyeight')
def configure_plot_style(fig, ax):
    """Apply the shared chart styling used throughout the app.

    Hides the top and right spines, draws a dashed, semi-transparent grid,
    and paints both the figure and the axes backgrounds white.

    Args:
        fig: Figure whose background patch is recoloured.
        ax: Axes whose spines, grid and background are styled.
    """
    background = 'white'
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_facecolor(background)
    fig.patch.set_facecolor(background)
def _draw_xy_plot(df, ax, cmap, plot_type):
    """Render X/Y column selectors and draw a line or bar chart.

    Args:
        df: The uploaded dataset.
        ax: Axes to draw on.
        cmap: Matplotlib colormap; a single colour is sampled from it.
        plot_type: "Line Plot" draws a line; any other value draws bars.

    Returns:
        True if a chart was drawn, False if the selection failed validation.
    """
    x_column = st.selectbox("Select X-axis column:", df.columns)
    y_column = st.selectbox("Select Y-axis column:", df.columns)
    if not pd.api.types.is_numeric_dtype(df[y_column]):
        st.warning("Y-axis column must be numeric for this plot type.")
        return False

    color = cmap(0.6)
    if plot_type == "Line Plot":
        ax.plot(df[x_column], df[y_column], marker='o', linewidth=2, color=color)
    else:
        ax.bar(df[x_column], df[y_column], color=color)
    ax.set_title(f"{plot_type} of {y_column} vs {x_column}", pad=20, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    # Rotate x tick labels only when there are enough distinct values to crowd the axis.
    ax.tick_params(axis='x', rotation=45 if df[x_column].nunique(dropna=False) > 10 else 0)
    return True


def _draw_scatter_plot(df, ax, cmap):
    """Render X/Y selectors and draw a scatter plot coloured by row position.

    Returns:
        True if a chart was drawn, False if the selection failed validation.
    """
    x_column = st.selectbox("Select X-axis column:", df.columns)
    y_column = st.selectbox("Select Y-axis column:", df.columns)
    if not (pd.api.types.is_numeric_dtype(df[x_column])
            and pd.api.types.is_numeric_dtype(df[y_column])):
        st.warning("Both X and Y columns must be numeric for scatter plot.")
        return False

    points = ax.scatter(df[x_column], df[y_column],
                        c=np.arange(len(df)), cmap=cmap,
                        alpha=0.6, s=100)
    ax.figure.colorbar(points, ax=ax, label='Index')
    ax.set_title(f"Scatter Plot of {y_column} vs {x_column}", pad=20, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    return True


def _draw_histogram(df, ax, cmap):
    """Render column/bin selectors and draw a histogram with graded bar colours.

    Returns:
        True if a chart was drawn, False if the selection failed validation.
    """
    column = st.selectbox("Select column:", df.columns)
    bin_count = st.slider("Number of bins:", min_value=5, max_value=50, value=20)
    if not pd.api.types.is_numeric_dtype(df[column]):
        st.warning("Column must be numeric for histogram.")
        return False

    # Drop missing values so only observed data is binned.
    _, _, patches = ax.hist(df[column].dropna(), bins=bin_count,
                            edgecolor='white', linewidth=1)
    n_patches = len(patches)
    for i, patch in enumerate(patches):
        patch.set_facecolor(cmap(i / n_patches))
    ax.set_title(f"Histogram of {column}", pad=20, fontsize=14)
    ax.set_xlabel(column, fontsize=12)
    ax.set_ylabel("Frequency", fontsize=12)
    return True


def _draw_box_plot(df, ax, cmap):
    """Render grouping/value selectors and draw one box per group.

    Returns:
        True if a chart was drawn, False if the selection failed validation.
    """
    x_column = st.selectbox("Select grouping column:", df.columns)
    y_column = st.selectbox("Select value column:", df.columns)
    if not pd.api.types.is_numeric_dtype(df[y_column]):
        st.warning("Value column must be numeric for box plot.")
        return False

    # Take each label and its values from the same groupby pass. groupby sorts
    # keys and drops NaN keys, so pairing its data with df[x].unique() (which
    # keeps first-appearance order and NaN) mislabels boxes.
    groups = [(key, group[y_column].dropna().to_numpy())
              for key, group in df.groupby(x_column)]
    if not groups:
        st.warning("Grouping column has no non-missing values for box plot.")
        return False
    labels = [str(key) for key, _ in groups]

    box_plot = ax.boxplot([values for _, values in groups], patch_artist=True)
    boxes = box_plot['boxes']
    n_boxes = len(boxes)
    for i, patch in enumerate(boxes):
        patch.set_facecolor(cmap(i / n_boxes))
        patch.set_alpha(0.7)
    # Set tick labels explicitly: boxplot's `labels` keyword was renamed to
    # `tick_labels` in Matplotlib 3.9. Boxes sit at positions 1..n by default.
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_title(f"Box Plot of {y_column} grouped by {x_column}", pad=20, fontsize=14)
    ax.set_xlabel(x_column, fontsize=12)
    ax.set_ylabel(y_column, fontsize=12)
    ax.tick_params(axis='x', rotation=45 if len(labels) > 10 else 0)
    return True


def _draw_correlation_matrix(df, ax, cmap):
    """Draw an annotated heatmap of pairwise correlations of numeric columns.

    Returns:
        True if a chart was drawn, False if there are no numeric columns.
    """
    # 'number' selects every numeric dtype (e.g. int32, float32), not only int64/float64.
    numeric_df = df.select_dtypes(include='number')
    if numeric_df.columns.empty:
        st.warning("No numeric columns found in the dataset for correlation matrix.")
        return False

    corr = numeric_df.corr()
    values = corr.to_numpy()
    # Pin the colour scale to the full correlation range [-1, 1] so colours
    # mean the same thing regardless of the data.
    image = ax.imshow(values, cmap=cmap, vmin=-1, vmax=1)
    ax.figure.colorbar(image, ax=ax)
    for (i, j), value in np.ndenumerate(values):
        ax.text(j, i, f'{value:.2f}',
                ha='center', va='center',
                color='white' if abs(value) > 0.5 else 'black')
    ticks = range(len(corr.columns))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(corr.columns, rotation=45, ha='right')
    ax.set_yticklabels(corr.columns)
    ax.set_title("Correlation Matrix", pad=20, fontsize=14)
    return True


st.title("Interactive Dataset Plotting Tool")

# Upload Dataset
uploaded_file = st.file_uploader("Upload your CSV dataset", type=["csv"])
if uploaded_file is not None:
    try:
        # Load dataset
        df = pd.read_csv(uploaded_file)
        st.write("Dataset Preview:")
        st.dataframe(df)

        # Plot type selection
        plot_types = ["Line Plot", "Bar Plot", "Scatter Plot", "Histogram", "Box Plot", "Correlation Matrix"]
        plot_type = st.selectbox("Select Plot Type:", plot_types)

        # Color scheme selection
        color_schemes = ['viridis', 'magma', 'plasma', 'inferno', 'cividis']
        color_scheme = st.selectbox("Select Color Scheme:", color_schemes)
        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported lookup.
        cmap = plt.get_cmap(color_scheme)

        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            configure_plot_style(fig, ax)
            plotted = False
            if plot_type in ("Line Plot", "Bar Plot"):
                plotted = _draw_xy_plot(df, ax, cmap, plot_type)
            elif plot_type == "Scatter Plot":
                plotted = _draw_scatter_plot(df, ax, cmap)
            elif plot_type == "Histogram":
                plotted = _draw_histogram(df, ax, cmap)
            elif plot_type == "Box Plot":
                plotted = _draw_box_plot(df, ax, cmap)
            elif plot_type == "Correlation Matrix":
                plotted = _draw_correlation_matrix(df, ax, cmap)

            # Only show and offer a download when a chart was actually drawn,
            # instead of displaying a blank figure after a validation warning.
            if plotted:
                fig.tight_layout()
                st.pyplot(fig)

                # Download button
                buffer = BytesIO()
                fig.savefig(buffer, format="png", dpi=300, bbox_inches='tight')
                buffer.seek(0)
                st.download_button(
                    label="Download Plot as PNG",
                    data=buffer,
                    file_name="plot.png",
                    mime="image/png"
                )
        finally:
            # pyplot keeps a reference to every figure made by plt.subplots until
            # it is closed, and this script body runs again on each interaction.
            plt.close(fig)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.info("Please make sure your dataset is properly formatted and contains appropriate data types for the selected plot type.")