from typing import Dict, List, Any, Optional, Tuple, Union
from contextlib import contextmanager
import io
import base64
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib

# Configure matplotlib for non-interactive environments. Selecting the Agg
# backend before pyplot is imported ensures pyplot starts with it.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402  (must come after matplotlib.use)
from llama_index.tools import FunctionTool  # noqa: E402


class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Every ``create_*`` method loads a CSV file from ``csv_directory`` (cached
    per filename), renders one matplotlib figure and returns a dict describing
    the chart, with the rendered PNG stored base64-encoded under ``"image"``.
    The same methods are exposed as LlamaIndex ``FunctionTool`` objects via
    :meth:`get_tools`.

    NOTE: the public ``create_*`` signatures (parameter names, annotations and
    defaults) are handed to ``FunctionTool.from_defaults``, which presumably
    derives each tool's argument schema from them -- change them with care.
    """

    def __init__(self, csv_directory: str):
        """Initialize with the directory that CSV filenames are resolved against."""
        self.csv_directory = csv_directory
        # Loaded DataFrames, keyed by the filename exactly as the caller passed it.
        self.dataframes = {}
        # Figure size (inches) and resolution used for every rendered chart.
        self.figure_size = (10, 6)
        self.dpi = 100
        self.tools = self._create_tools()

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as a DataFrame, caching it by ``filename``.

        ``filename`` is resolved relative to ``csv_directory``. If no file with
        that exact name exists and it lacks a ``.csv`` suffix, ``<filename>.csv``
        is tried instead.

        Raises:
            ValueError: If no matching regular file is found.
        """
        if filename in self.dataframes:
            return self.dataframes[filename]

        base_dir = Path(self.csv_directory)
        file_path = base_dir / filename
        # is_file() rather than exists(): a directory of the same name must not
        # be handed to read_csv.
        if not file_path.is_file() and not filename.endswith('.csv'):
            file_path = base_dir / f"{filename}.csv"
        # NOTE(review): filename arrives via the FunctionTool wrappers; an
        # absolute path or "../" components escape csv_directory because
        # Path "/" does not confine the result. Consider rejecting paths that
        # resolve outside the directory.
        if not file_path.is_file():
            raise ValueError(f"CSV file not found: {filename}")

        df = pd.read_csv(file_path)
        self.dataframes[filename] = df
        return df

    @staticmethod
    def _check_columns(df: pd.DataFrame, *columns: str) -> None:
        """Raise KeyError (listing the available columns) if any column is missing.

        Checking before a figure is created gives a more helpful message than
        pandas' bare KeyError while keeping the same exception type.
        """
        missing = [col for col in columns if col not in df.columns]
        if missing:
            raise KeyError(
                f"Column(s) not found: {missing}; "
                f"available columns: {list(df.columns)}"
            )

    def _create_tools(self) -> List[FunctionTool]:
        """Wrap each public chart method as a LlamaIndex ``FunctionTool``."""
        specs = [
            ("create_line_chart", "Create a line chart from CSV data", self.create_line_chart),
            ("create_bar_chart", "Create a bar chart from CSV data", self.create_bar_chart),
            ("create_scatter_plot", "Create a scatter plot from CSV data", self.create_scatter_plot),
            ("create_histogram", "Create a histogram from CSV data", self.create_histogram),
            ("create_pie_chart", "Create a pie chart from CSV data", self.create_pie_chart),
        ]
        return [
            FunctionTool.from_defaults(name=name, description=description, fn=fn)
            for name, description, fn in specs
        ]

    def get_tools(self) -> List[FunctionTool]:
        """Get all available visualization tools."""
        return self.tools

    @contextmanager
    def _new_figure(self):
        """Yield ``(fig, ax)`` for a new figure and close the figure on exit.

        Without this, an exception raised while plotting would leave the
        figure registered with pyplot and leak it in long-running processes.
        Closing a figure pyplot no longer tracks (e.g. one already closed by
        ``_figure_to_base64``) does nothing.
        """
        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            yield fig, ax
        finally:
            plt.close(fig)

    def _figure_to_base64(self, fig) -> str:
        """Render ``fig`` as a PNG and return it base64-encoded; always closes ``fig``."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            # Close even if savefig fails so pyplot does not keep the figure alive.
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    # Visualization tool implementations

    def create_line_chart(self, filename: str, x_column: str, y_column: str,
                          title: str = None, limit: int = 50) -> Dict[str, Any]:
        """Create a line chart of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file name (``.csv`` suffix optional).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.
            limit: Plot at most the first ``limit`` rows.

        Returns:
            Dict with ``chart_type`` (``"line"``), ``x_column``, ``y_column``,
            ``data_points`` (number of rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column does not exist.
        """
        df = self._load_dataframe(filename)
        self._check_columns(df, x_column, y_column)
        # head() returns every row when the frame already has <= limit rows.
        df = df.head(limit)

        with self._new_figure() as (fig, ax):
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)
            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_bar_chart(self, filename: str, x_column: str, y_column: str,
                         title: str = None, limit: int = 20) -> Dict[str, Any]:
        """Create a bar chart of ``y_column`` for each ``x_column`` value.

        Args:
            filename: CSV file name (``.csv`` suffix optional).
            x_column: Column providing the bar positions/categories.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to ``"<y_column> by <x_column>"``.
            limit: Plot at most the first ``limit`` rows.

        Returns:
            Dict with ``chart_type`` (``"bar"``), ``x_column``, ``y_column``,
            ``categories`` (number of rows plotted) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column does not exist.
        """
        df = self._load_dataframe(filename)
        self._check_columns(df, x_column, y_column)
        df = df.head(limit)

        with self._new_figure() as (fig, ax):
            ax.bar(df[x_column], df[y_column])
            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # Rotate x labels if there are many categories. Operate on this
            # figure's Axes rather than pyplot's implicit "current" figure.
            if len(df) > 5:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            fig.tight_layout()

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str
        }

    def create_scatter_plot(self, filename: str, x_column: str, y_column: str,
                            color_column: str = None, title: str = None) -> Dict[str, Any]:
        """Create a scatter plot of ``y_column`` against ``x_column``.

        Args:
            filename: CSV file name (``.csv`` suffix optional).
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column mapped to point colour (viridis) with
                a colour bar; silently ignored if the column does not exist.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.

        Returns:
            Dict with ``chart_type`` (``"scatter"``), ``x_column``, ``y_column``,
            ``color_column`` (as passed), ``data_points`` (row count) and
            ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``x_column`` or ``y_column`` does not exist.
        """
        df = self._load_dataframe(filename)
        self._check_columns(df, x_column, y_column)

        with self._new_figure() as (fig, ax):
            if color_column and color_column in df.columns:
                points = ax.scatter(df[x_column], df[y_column], c=df[color_column],
                                    cmap='viridis', alpha=0.7)
                fig.colorbar(points, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str
        }

    def create_histogram(self, filename: str, column: str, bins: int = 10,
                         title: str = None) -> Dict[str, Any]:
        """Create a histogram of the values in ``column``.

        Args:
            filename: CSV file name (``.csv`` suffix optional).
            column: Column whose value distribution is plotted.
            bins: Number of histogram bins.
            title: Chart title; defaults to ``"Distribution of <column>"``.

        Returns:
            Dict with ``chart_type`` (``"histogram"``), ``column``, ``bins``,
            ``data_points`` (total rows in the file, including rows whose value
            is missing) and ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If ``column`` does not exist.
        """
        df = self._load_dataframe(filename)
        self._check_columns(df, column)

        with self._new_figure() as (fig, ax):
            # Drop missing values explicitly: for an all-NaN column the
            # automatically derived bin range is not finite and hist() fails.
            ax.hist(df[column].dropna(), bins=bins, alpha=0.7, edgecolor='black')
            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str
        }

    def create_pie_chart(self, filename: str, label_column: str, value_column: str = None,
                         title: str = None, limit: int = 10) -> Dict[str, Any]:
        """Create a pie chart of the top ``limit`` categories in ``label_column``.

        Args:
            filename: CSV file name (``.csv`` suffix optional).
            label_column: Column whose distinct values become the pie slices.
            value_column: If given, each slice is the sum of this column for
                that label; otherwise each slice is the label's row count.
            title: Chart title; defaults to ``"Distribution of <label_column>"``.
            limit: Keep only the ``limit`` largest slices.

        Returns:
            Dict with ``chart_type`` (``"pie"``), ``label_column``,
            ``value_column`` (as passed), ``categories`` (number of slices) and
            ``image`` (base64 PNG).

        Raises:
            ValueError: If the CSV file cannot be found.
            KeyError: If a requested column does not exist.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            self._check_columns(df, label_column)
            # Count occurrences of each label; value_counts sorts by count, descending.
            data = df[label_column].value_counts().head(limit)
        else:
            self._check_columns(df, label_column, value_column)
            # Sum values per label and keep the largest categories. Working on
            # the grouped Series avoids reset_index(), which fails when
            # label_column and value_column are the same column.
            data = df.groupby(label_column)[value_column].sum().nlargest(limit)

        labels = data.index.tolist()
        values = data.tolist()

        with self._new_figure() as (fig, ax):
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90, shadow=True)
            ax.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
            ax.set_title(title or f"Distribution of {label_column}")

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str
        }