|
from typing import Dict, List, Any, Optional, Tuple, Union
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib
|
|
import io
|
|
import base64
|
|
import numpy as np
|
|
from llama_index.tools import FunctionTool
|
|
from pathlib import Path
|
|
|
|
|
|
# Select the non-interactive Agg backend: figures in this module are only
# rendered to in-memory PNG buffers, never shown in a GUI window.
matplotlib.use('Agg')
|
|
|
|
class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Each ``create_*`` method loads a CSV file from ``csv_directory`` (cached
    per requested filename), draws a matplotlib figure and returns a dict of
    chart metadata with the rendered PNG as a base64 string under ``"image"``.
    The same methods are also exposed as LlamaIndex ``FunctionTool`` objects
    through :meth:`get_tools`.
    """

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files.

        Args:
            csv_directory: Directory that ``filename`` arguments are resolved
                against.
        """
        self.csv_directory = csv_directory
        # Cache of loaded frames, keyed by the filename exactly as requested.
        self.dataframes: Dict[str, pd.DataFrame] = {}
        # Rendering settings are assigned before the tools are built so the
        # instance is fully configured by the time anything can call into it.
        self.figure_size = (10, 6)
        self.dpi = 100
        self.tools = self._create_tools()

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching.

        Args:
            filename: File name relative to ``csv_directory``. If no such file
                exists and the name does not end in ``.csv``, the same name
                with a ``.csv`` suffix is tried.

        Returns:
            The loaded (or previously cached) DataFrame.

        Raises:
            ValueError: If no matching CSV file exists.
        """
        cached = self.dataframes.get(filename)
        if cached is not None:
            return cached

        base_dir = Path(self.csv_directory)
        file_path = base_dir / filename
        # is_file() rather than exists(): a directory with the requested name
        # must not be handed to read_csv, and must not block the .csv fallback.
        if not file_path.is_file() and not filename.endswith('.csv'):
            file_path = base_dir / f"{filename}.csv"

        if not file_path.is_file():
            raise ValueError(f"CSV file not found: {filename}")

        df = pd.read_csv(file_path)
        self.dataframes[filename] = df
        return df

    @staticmethod
    def _require_columns(df: pd.DataFrame, filename: str, *columns: str) -> None:
        """Raise ``KeyError`` naming any of ``columns`` absent from ``df``.

        Called before a figure is created so that a bad column name fails
        early, with a message listing the columns that are available.
        """
        missing = [name for name in columns if name not in df.columns]
        if missing:
            raise KeyError(
                f"Column(s) {missing} not found in {filename}; "
                f"available columns: {list(df.columns)}"
            )

    def _create_tools(self) -> "List[FunctionTool]":
        """Create LlamaIndex function tools for visualizations.

        Returns:
            One ``FunctionTool`` per public ``create_*`` chart method.
        """
        specs = [
            ("create_line_chart", "Create a line chart from CSV data",
             self.create_line_chart),
            ("create_bar_chart", "Create a bar chart from CSV data",
             self.create_bar_chart),
            ("create_scatter_plot", "Create a scatter plot from CSV data",
             self.create_scatter_plot),
            ("create_histogram", "Create a histogram from CSV data",
             self.create_histogram),
            ("create_pie_chart", "Create a pie chart from CSV data",
             self.create_pie_chart),
        ]
        return [
            FunctionTool.from_defaults(name=name, description=description, fn=fn)
            for name, description, fn in specs
        ]

    def get_tools(self) -> "List[FunctionTool]":
        """Get all available visualization tools."""
        return self.tools

    def _figure_to_base64(self, fig) -> str:
        """Convert matplotlib figure to base64 encoded string.

        The figure is rendered as PNG at ``self.dpi`` and is always closed
        afterwards, even when rendering fails.

        Args:
            fig: The matplotlib figure to render.

        Returns:
            The PNG bytes encoded as a base64 UTF-8 string.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            # pyplot keeps a reference to every open figure; release it.
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: Optional[str] = None, limit: int = 50) -> Dict[str, Any]:
        """Create a line chart visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``data_points`` (rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, x_column, y_column)
        df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)
            img_str = self._figure_to_base64(fig)
        finally:
            # Harmless if already closed; guarantees no leak when plotting fails.
            plt.close(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: Optional[str] = None, limit: int = 20) -> Dict[str, Any]:
        """Create a bar chart visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            x_column: Column providing the bar positions/labels.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to ``"<y_column> by <x_column>"``.
            limit: Only the first ``limit`` rows are plotted.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``categories`` (rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, x_column, y_column)
        df = df.head(limit)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.bar(df[x_column], df[y_column])
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # With more than a handful of bars, slant the labels so they do
            # not overlap. Done on this Axes explicitly rather than through
            # pyplot's implicit "current axes" state.
            if len(df) > 5:
                for label in ax.get_xticklabels():
                    label.set_rotation(45)
                    label.set_horizontalalignment('right')

            fig.tight_layout()
            img_str = self._figure_to_base64(fig)
        finally:
            # Harmless if already closed; guarantees no leak when plotting fails.
            plt.close(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: Optional[str] = None,
                            title: Optional[str] = None) -> Dict[str, Any]:
        """Create a scatter plot visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column mapped to point colour with a
                colorbar. It is ignored when it is not a column of the file.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.

        Returns:
            Dict with ``chart_type``, ``x_column``, ``y_column``,
            ``color_column`` (as passed in), ``data_points`` (row count) and
            ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, x_column, y_column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            if color_column and color_column in df.columns:
                scatter = ax.scatter(df[x_column], df[y_column], c=df[color_column],
                                     cmap='viridis', alpha=0.7)
                fig.colorbar(scatter, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)
            img_str = self._figure_to_base64(fig)
        finally:
            # Harmless if already closed; guarantees no leak when plotting fails.
            plt.close(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: Optional[str] = None) -> Dict[str, Any]:
        """Create a histogram visualization.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            column: Column whose values are binned.
            bins: Number of bins passed to ``Axes.hist``.
            title: Chart title; defaults to ``"Distribution of <column>"``.

        Returns:
            Dict with ``chart_type``, ``column``, ``bins``, ``data_points``
            (row count of the file) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``column`` is not in the file.
        """
        df = self._load_dataframe(filename)
        self._require_columns(df, filename, column)

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')
            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)
            img_str = self._figure_to_base64(fig)
        finally:
            # Harmless if already closed; guarantees no leak when plotting fails.
            plt.close(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str
        }

    def create_pie_chart(self, filename: str, label_column: str,
                         value_column: Optional[str] = None,
                         title: Optional[str] = None, limit: int = 10) -> Dict[str, Any]:
        """Create a pie chart visualization.

        Without ``value_column`` the slices are the occurrence counts of each
        distinct ``label_column`` value. With it, ``value_column`` is summed
        per label. In both cases only the ``limit`` largest slices are kept.

        Args:
            filename: CSV file to load (see ``_load_dataframe``).
            label_column: Column providing the slice labels.
            value_column: Optional column summed per label for slice sizes.
            title: Chart title; defaults to
                ``"Distribution of <label_column>"``.
            limit: Maximum number of slices.

        Returns:
            Dict with ``chart_type``, ``label_column``, ``value_column``,
            ``categories`` (number of slices) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``label_column`` or a given ``value_column`` is not
                in the file.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            self._require_columns(df, filename, label_column)
            # value_counts() is already sorted descending, so head() keeps
            # the most frequent labels.
            counts = df[label_column].value_counts().head(limit)
            labels = counts.index.tolist()
            values = counts.values.tolist()
        else:
            self._require_columns(df, filename, label_column, value_column)
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
            # Equal aspect ratio so the pie is drawn as a circle.
            ax.axis('equal')
            ax.set_title(title or f"Distribution of {label_column}")
            img_str = self._figure_to_base64(fig)
        finally:
            # Harmless if already closed; guarantees no leak when plotting fails.
            plt.close(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str
        }
|
|
|