Autonomous_Data_Scientist / src /visualizer.py
Megha Panicker
Resolve README.md conflict with HF Spaces metadata
c1b226b
"""
Data synthesis and visualization pipeline using Pandas, Matplotlib, and Seaborn.
Determines the best visual representation and saves high-resolution images.
Charts use a professional, dashboard-style look with currency formatting and clear typography.
"""
from pathlib import Path
from typing import Any, List, Optional
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns
from src.models import ChartType, VisualizationConfig
# Dashboard colour scheme: muted slate greys plus a few blue accents.
CHART_COLOR = "#334155"  # fill used for single-series bars, lines, scatter points and histograms
CHART_COLORS = [  # palette handed to pie charts
    "#475569",
    "#64748b",
    "#94a3b8",
    "#0ea5e9",
    "#0369a1",
    "#1e40af",
]
MAX_BAR_CATEGORIES = 15  # upper bound on bars drawn so tick labels stay legible
def _is_currency_column(name: str) -> bool:
"""Heuristic: treat as currency if name suggests money."""
if not name:
return False
n = str(name).lower()
return any(k in n for k in ("amount", "total", "revenue", "sales", "sum", "price", "value"))
def _currency_tick_label(value, _pos=None) -> str:
    """Render one tick value as a compact dollar label.

    Magnitudes of at least 1,000 get a K/M/B suffix with up to two decimals
    (1500 -> "$1.5K", 2_000_000 -> "$2M"), so closely spaced ticks stay
    distinct. Smaller magnitudes are rounded to whole dollars. The sign goes
    before the dollar sign ("-$5K").
    """
    sign = "-" if value < 0 else ""
    magnitude = abs(value)
    for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
        if magnitude >= threshold:
            # Strip trailing zeros so round ticks read "$2K" rather than "$2.00K".
            scaled = f"{magnitude / threshold:,.2f}".rstrip("0").rstrip(".")
            return f"{sign}${scaled}{suffix}"
    whole = f"{magnitude:,.0f}"
    # Tiny negatives that round to zero should not display as "-$0".
    return f"{sign if whole != '0' else ''}${whole}"


def _format_currency_axis(ax, axis="y"):
    """Attach a compact currency tick formatter (K/M/B suffixes) to one axis.

    Args:
        ax: Matplotlib Axes, modified in place.
        axis: ``"y"`` selects the y-axis; any other value selects the x-axis.
    """
    ax_to_use = ax.yaxis if axis == "y" else ax.xaxis
    ax_to_use.set_major_formatter(ticker.FuncFormatter(_currency_tick_label))
def _plain_tick_label(value, _pos=None) -> str:
    """Format a non-currency tick value.

    Whole numbers, and anything of magnitude 1,000 or more, are shown with
    thousands separators and no decimals. Other fractional values keep up to
    four significant digits, so ticks such as 0.2/0.4/0.6 are not all
    rounded to "0" or "1".
    """
    if abs(value) >= 1000 or float(value).is_integer():
        return f"{value:,.0f}"
    return f"{value:.4g}"


def _setup_professional_style(ax, y_col=None):
    """Apply the shared dashboard look to *ax*.

    Hides the top and right spines, draws a faint horizontal grid behind the
    data, and formats y tick labels. Currency-looking *y_col* names (per
    ``_is_currency_column``) get dollar labels; everything else gets plain
    numbers.

    Args:
        ax: Matplotlib Axes, modified in place.
        y_col: Name of the column plotted on the y-axis, if any.
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="y", alpha=0.3, linestyle="-")
    ax.set_axisbelow(True)  # keep grid lines underneath bars/points
    if y_col and _is_currency_column(y_col):
        _format_currency_axis(ax, "y")
    else:
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(_plain_tick_label))
def infer_visualization(data: List[dict]) -> VisualizationConfig:
    """
    Infer the best visualization type from the data structure.

    Rules, applied in order:
      * no rows -> placeholder BAR titled "No Data";
      * a datetime column and a numeric column, with no categorical columns
        -> LINE of the first numeric column over the first datetime column;
      * two or more numeric columns and no categorical ones -> SCATTER;
      * categorical and numeric columns -> BAR of the first numeric column
        by the first categorical column;
      * exactly one numeric column -> HISTOGRAM;
      * only categorical columns -> BAR of category counts;
      * otherwise -> BAR over the first two columns.

    Columns whose values are all null are ignored when classifying columns.
    pandas gives them object dtype, so they would otherwise look categorical.
    """
    if not data:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title="No Data",
            x_column=None,
            y_column=None,
        )
    df = pd.DataFrame(data)
    cols = list(df.columns)
    # Classify only columns that hold at least one value (see docstring).
    typed = df.dropna(axis=1, how="all")
    numeric_cols = typed.select_dtypes(include=["number"]).columns.tolist()
    cat_cols = typed.select_dtypes(include=["object", "category"]).columns.tolist()
    time_cols = typed.select_dtypes(include=["datetime", "datetimetz"]).columns.tolist()

    # Time series: values over a datetime axis, with nothing to group by.
    if time_cols and numeric_cols and not cat_cols:
        return VisualizationConfig(
            chart_type=ChartType.LINE,
            title=f"{numeric_cols[0]} over {time_cols[0]}",
            x_column=time_cols[0],
            y_column=numeric_cols[0],
        )
    if len(numeric_cols) >= 2 and not cat_cols:
        return VisualizationConfig(
            chart_type=ChartType.SCATTER,
            title="Scatter Plot",
            x_column=numeric_cols[0],
            y_column=numeric_cols[1],
        )
    if cat_cols and numeric_cols:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title=f"{numeric_cols[0]} by {cat_cols[0]}",
            x_column=cat_cols[0],
            y_column=numeric_cols[0],
        )
    if len(numeric_cols) == 1:
        return VisualizationConfig(
            chart_type=ChartType.HISTOGRAM,
            title=f"Distribution of {numeric_cols[0]}",
            x_column=numeric_cols[0],
            y_column=None,
        )
    if cat_cols:
        return VisualizationConfig(
            chart_type=ChartType.BAR,
            title=f"Count of {cat_cols[0]}",
            x_column=cat_cols[0],
            y_column=None,
        )
    # Fallback uses the unfiltered column list so x_column is only None when
    # the frame has no columns at all.
    return VisualizationConfig(
        chart_type=ChartType.BAR,
        title="Data Overview",
        x_column=cols[0] if cols else None,
        y_column=cols[1] if len(cols) > 1 else None,
    )
def _save_figure(fig, output_path: str) -> None:
    """Write *fig* to *output_path* at 150 dpi on white, creating parent dirs first."""
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor="white")


def _truncate_label(value: Any, width: int) -> str:
    """Return ``str(value)`` cut to *width* characters, with "…" appended when cut."""
    text = str(value)
    return text[:width] + "…" if len(text) > width else text


def _rotate_x_labels(ax) -> None:
    """Tilt x tick labels 45° (right-aligned) so long category names don't overlap."""
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")


def _style_horizontal(ax, value_col: Optional[str]) -> None:
    """Dashboard styling for horizontal bars, where values run along the x-axis.

    The y-axis (categories) keeps its own tick labels. Formatting it as
    numbers would replace the category names.
    """
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="-")
    ax.set_axisbelow(True)
    if value_col and _is_currency_column(value_col):
        _format_currency_axis(ax, "x")
    else:
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, _pos: f"{x:,.0f}"))


def _draw_fallback_bar(ax, df: pd.DataFrame) -> None:
    """Bar chart of the first column (as labels) against the second column.

    If the second column is not numeric, bar heights are the row positions
    0, 1, 2, …. At most ``MAX_BAR_CATEGORIES`` rows are drawn. Nothing is
    drawn when the frame has fewer than two columns.
    """
    cols = list(df.columns)[:2]
    if len(cols) < 2:
        return
    label_col, value_col = cols
    rows = df.head(MAX_BAR_CATEGORIES)
    is_numeric = pd.api.types.is_numeric_dtype(rows[value_col])
    heights = rows[value_col] if is_numeric else list(range(len(rows)))
    labels = [_truncate_label(v, 20) for v in rows[label_col]]
    ax.bar(labels, heights, color=CHART_COLOR, edgecolor="white", linewidth=0.5)
    _setup_professional_style(ax, value_col if is_numeric else None)
    _rotate_x_labels(ax)


def _draw_chart(fig, ax, df: pd.DataFrame, config: VisualizationConfig) -> None:
    """Draw the chart described by *config* onto *ax*.

    Each chart type runs only when the columns it needs are set. Otherwise,
    and for HEATMAP with fewer than two numeric columns, the frame is drawn
    by ``_draw_fallback_bar``.
    """
    chart_type = config.chart_type
    x_col = config.x_column
    y_col = config.y_column
    numeric = df.select_dtypes(include=["number"])

    if chart_type == ChartType.LINE and x_col and y_col:
        # Only the first 100 rows in x order are plotted, to keep the axis legible.
        df_plot = df.sort_values(x_col).head(100)
        ax.plot(df_plot[x_col].astype(str), df_plot[y_col], color=CHART_COLOR, marker="o", markersize=5, linewidth=2)
        _setup_professional_style(ax, y_col)
        _rotate_x_labels(ax)
    elif chart_type == ChartType.BAR and x_col:
        if y_col:
            totals = df.groupby(x_col, as_index=False)[y_col].sum()
            totals = totals.sort_values(y_col, ascending=False).head(MAX_BAR_CATEGORIES)
            keys, heights = totals[x_col], totals[y_col]
        else:
            counts = df[x_col].value_counts().head(MAX_BAR_CATEGORIES)
            keys, heights = counts.index, counts.values
        labels = [_truncate_label(k, 20) for k in keys]
        ax.bar(labels, heights, color=CHART_COLOR, edgecolor="white", linewidth=0.5)
        _setup_professional_style(ax, y_col)
        _rotate_x_labels(ax)
    elif chart_type == ChartType.BARH and x_col:
        if y_col:
            totals = df.groupby(x_col, as_index=False)[y_col].sum()
            # Keep the largest groups first, then order ascending so barh draws the biggest on top.
            totals = totals.sort_values(y_col, ascending=False).head(MAX_BAR_CATEGORIES)
            totals = totals.sort_values(y_col, ascending=True)
            keys, widths = totals[x_col], totals[y_col]
        else:
            counts = df[x_col].value_counts().head(MAX_BAR_CATEGORIES).sort_values(ascending=True)
            keys, widths = counts.index, counts.values
        labels = [_truncate_label(k, 25) for k in keys]
        ax.barh(labels, widths, color=CHART_COLOR, edgecolor="white", linewidth=0.5)
        _style_horizontal(ax, y_col)
    elif chart_type == ChartType.PIE and x_col:
        counts = df[x_col].value_counts().head(10)
        ax.pie(counts.values, labels=counts.index, autopct="%1.1f%%", colors=CHART_COLORS, startangle=90)
        ax.axis("equal")
    elif chart_type == ChartType.SCATTER and x_col and y_col:
        ax.scatter(df[x_col], df[y_col], color=CHART_COLOR, alpha=0.6, s=40)
        _setup_professional_style(ax, y_col)
    elif chart_type == ChartType.HISTOGRAM and x_col:
        values = df[x_col].dropna()
        # At most 25 bins and no more bins than non-null values (minimum 1).
        ax.hist(values, bins=max(1, min(25, len(values))), color=CHART_COLOR, edgecolor="white")
        _setup_professional_style(ax)
    elif chart_type == ChartType.BOX and y_col:
        if x_col:
            df.boxplot(column=y_col, by=x_col, ax=ax)
            # pandas adds a "Boxplot grouped by ..." figure title; config.title replaces it.
            fig.suptitle("")
        else:
            df.boxplot(column=y_col, ax=ax)
        _setup_professional_style(ax, y_col)
    elif chart_type == ChartType.HEATMAP and len(numeric.columns) >= 2:
        sns.heatmap(numeric.corr(), annot=True, cmap="Blues", ax=ax, fmt=".2f")
    else:
        _draw_fallback_bar(ax, df)


def create_visualization(
    data: List[dict],
    config: VisualizationConfig,
    output_path: str = "output_chart.png",
) -> str:
    """
    Create and save a high-resolution chart based on the config.

    Args:
        data: Row records, converted with ``pd.DataFrame(data)``.
        config: Chart type, title and x/y column names to plot.
        output_path: Destination image path. Parent directories are created.

    Returns:
        ``output_path`` after the image has been written. An empty frame
        produces a "No data to visualize" placeholder image.

    A chart type whose required columns are missing is drawn as a simple bar
    chart of the first two columns instead of raising. The matplotlib figure
    is always closed, even if plotting fails.
    """
    df = pd.DataFrame(data)
    if df.empty:
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.text(0.5, 0.5, "No data to visualize", ha="center", va="center", fontsize=14)
            ax.axis("off")
            _save_figure(fig, output_path)
        finally:
            plt.close(fig)
        return output_path

    sns.set_style("ticks")
    plt.rc("font", size=11)
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        fig.patch.set_facecolor("white")
        ax.set_facecolor("#fafafa")
        _draw_chart(fig, ax, df, config)
        ax.set_title(config.title, fontsize=14, fontweight="600", color="#1e293b")
        fig.tight_layout()
        _save_figure(fig, output_path)
    finally:
        # pyplot keeps every open figure alive; close it even when plotting
        # raises, so a long-running process does not leak figures.
        plt.close(fig)
    return output_path