from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import base64
import io

import numpy as np
import pandas as pd
import matplotlib

# Select the non-interactive backend *before* pyplot is imported so that no
# GUI toolkit is ever probed in headless environments.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)


class VisualizationTools:
    """Tools for creating visualizations from CSV data.

    Each ``create_*`` method loads a CSV file from ``csv_directory``, renders
    a chart with matplotlib and returns a dictionary describing the chart,
    including the PNG image as a base64-encoded string under ``"image"``.
    """

    def __init__(self, csv_directory: str):
        """Initialize with directory containing CSV files.

        Args:
            csv_directory: Directory that CSV file names are resolved against.
        """
        self.csv_directory = csv_directory
        # Cache of loaded DataFrames, keyed by the file name as passed in.
        self.dataframes: Dict[str, pd.DataFrame] = {}
        # Figure size (inches) and resolution used for every rendered chart.
        self.figure_size = (10, 6)
        self.dpi = 100

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as DataFrame, with caching.

        The name is first tried as given; if that is not a file and the name
        does not already end in ``.csv``, ``<filename>.csv`` is tried instead.
        Results are cached under the name exactly as passed, so ``"data"``
        and ``"data.csv"`` occupy separate cache entries.

        Args:
            filename: CSV file name relative to ``csv_directory``.

        Returns:
            The loaded (possibly cached) DataFrame.

        Raises:
            ValueError: If no matching CSV file is found.
        """
        if filename in self.dataframes:
            return self.dataframes[filename]

        base_dir = Path(self.csv_directory)
        file_path = base_dir / filename
        # is_file() rather than exists(): a directory that happens to match
        # the name must neither be handed to read_csv nor shadow "<name>.csv".
        if not file_path.is_file() and not filename.endswith('.csv'):
            file_path = base_dir / f"{filename}.csv"
        if not file_path.is_file():
            raise ValueError(f"CSV file not found: {filename}")

        self.dataframes[filename] = pd.read_csv(file_path)
        return self.dataframes[filename]

    def get_tools(self) -> List[Dict[str, Any]]:
        """Get all available visualization tools.

        Returns:
            A list of dictionaries, each with a ``name``, a ``description``
            and the bound method under ``function``.
        """
        return [
            {
                "name": "create_line_chart",
                "description": "Create a line chart from CSV data",
                "function": self.create_line_chart,
            },
            {
                "name": "create_bar_chart",
                "description": "Create a bar chart from CSV data",
                "function": self.create_bar_chart,
            },
            {
                "name": "create_scatter_plot",
                "description": "Create a scatter plot from CSV data",
                "function": self.create_scatter_plot,
            },
            {
                "name": "create_histogram",
                "description": "Create a histogram from CSV data",
                "function": self.create_histogram,
            },
            {
                "name": "create_pie_chart",
                "description": "Create a pie chart from CSV data",
                "function": self.create_pie_chart,
            },
        ]

    @contextmanager
    def _managed_figure(self) -> Iterator[Tuple[Any, Any]]:
        """Yield a new ``(figure, axes)`` pair and always close the figure.

        pyplot keeps every figure it creates registered until it is closed,
        so a chart method that raises midway (for example on a missing
        column) would otherwise leak the figure for the life of the process.
        """
        fig, ax = plt.subplots(figsize=self.figure_size)
        try:
            yield fig, ax
        finally:
            # Closing a figure that was already closed is a no-op.
            plt.close(fig)

    def _figure_to_base64(self, fig) -> str:
        """Convert matplotlib figure to base64 encoded string.

        The figure is closed afterwards, whether or not saving succeeds.

        Args:
            fig: The matplotlib figure to render.

        Returns:
            The PNG image, base64-encoded and decoded to a UTF-8 string.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', dpi=self.dpi)
        finally:
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    # Visualization tool implementations

    def create_line_chart(
        self,
        filename: str,
        x_column: str,
        y_column: str,
        title: Optional[str] = None,
        limit: int = 50,
    ) -> Dict[str, Any]:
        """Create a line chart visualization.

        Args:
            filename: CSV file name, resolved by ``_load_dataframe``.
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.
            limit: Maximum number of leading rows to plot.

        Returns:
            A dictionary with ``chart_type``, ``x_column``, ``y_column``,
            ``data_points`` (rows plotted) and the base64 ``image``.

        Raises:
            ValueError: If the CSV file is not found.
            KeyError: If a requested column is not in the data.
        """
        df = self._load_dataframe(filename)

        # Limit data points if needed
        if len(df) > limit:
            df = df.head(limit)

        with self._managed_figure() as (fig, ax):
            ax.plot(df[x_column], df[y_column], marker='o', linestyle='-')

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True)

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "line",
            "x_column": x_column,
            "y_column": y_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_bar_chart(
        self,
        filename: str,
        x_column: str,
        y_column: str,
        title: Optional[str] = None,
        limit: int = 20,
    ) -> Dict[str, Any]:
        """Create a bar chart visualization.

        Args:
            filename: CSV file name, resolved by ``_load_dataframe``.
            x_column: Column providing the bar positions/categories.
            y_column: Column providing the bar heights.
            title: Chart title; defaults to ``"<y_column> by <x_column>"``.
            limit: Maximum number of leading rows (bars) to plot.

        Returns:
            A dictionary with ``chart_type``, ``x_column``, ``y_column``,
            ``categories`` (rows plotted) and the base64 ``image``.

        Raises:
            ValueError: If the CSV file is not found.
            KeyError: If a requested column is not in the data.
        """
        df = self._load_dataframe(filename)

        # Limit categories if needed
        if len(df) > limit:
            df = df.head(limit)

        with self._managed_figure() as (fig, ax):
            ax.bar(df[x_column], df[y_column])

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} by {x_column}")

            # Rotate x labels if there are many categories. The labels are
            # styled on this axes directly rather than through pyplot's
            # implicit "current figure" state.
            if len(df) > 5:
                for tick_label in ax.get_xticklabels():
                    tick_label.set_rotation(45)
                    tick_label.set_horizontalalignment('right')

            fig.tight_layout()

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "bar",
            "x_column": x_column,
            "y_column": y_column,
            "categories": len(df),
            "image": img_str,
        }

    def create_scatter_plot(
        self,
        filename: str,
        x_column: str,
        y_column: str,
        color_column: Optional[str] = None,
        title: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create a scatter plot visualization.

        Args:
            filename: CSV file name, resolved by ``_load_dataframe``.
            x_column: Column plotted on the x axis.
            y_column: Column plotted on the y axis.
            color_column: Optional column used to color the points (with a
                colorbar). It is ignored when it is not a column of the
                data, but is still echoed back in the result.
            title: Chart title; defaults to ``"<y_column> vs <x_column>"``.

        Returns:
            A dictionary with ``chart_type``, ``x_column``, ``y_column``,
            ``color_column``, ``data_points`` and the base64 ``image``.

        Raises:
            ValueError: If the CSV file is not found.
            KeyError: If ``x_column`` or ``y_column`` is not in the data.
        """
        df = self._load_dataframe(filename)

        with self._managed_figure() as (fig, ax):
            if color_column and color_column in df.columns:
                scatter = ax.scatter(
                    df[x_column],
                    df[y_column],
                    c=df[color_column],
                    cmap='viridis',
                    alpha=0.7,
                )
                fig.colorbar(scatter, ax=ax, label=color_column)
            else:
                ax.scatter(df[x_column], df[y_column], alpha=0.7)

            ax.set_xlabel(x_column)
            ax.set_ylabel(y_column)
            ax.set_title(title or f"{y_column} vs {x_column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "scatter",
            "x_column": x_column,
            "y_column": y_column,
            "color_column": color_column,
            "data_points": len(df),
            "image": img_str,
        }

    def create_histogram(
        self,
        filename: str,
        column: str,
        bins: int = 10,
        title: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create a histogram visualization.

        Args:
            filename: CSV file name, resolved by ``_load_dataframe``.
            column: Column whose values are binned.
            bins: Number of bins passed to matplotlib.
            title: Chart title; defaults to ``"Distribution of <column>"``.

        Returns:
            A dictionary with ``chart_type``, ``column``, ``bins``,
            ``data_points`` (rows in the file) and the base64 ``image``.

        Raises:
            ValueError: If the CSV file is not found.
            KeyError: If ``column`` is not in the data.
        """
        df = self._load_dataframe(filename)

        with self._managed_figure() as (fig, ax):
            ax.hist(df[column], bins=bins, alpha=0.7, edgecolor='black')

            ax.set_xlabel(column)
            ax.set_ylabel('Frequency')
            ax.set_title(title or f"Distribution of {column}")
            ax.grid(True, linestyle='--', alpha=0.7)

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "histogram",
            "column": column,
            "bins": bins,
            "data_points": len(df),
            "image": img_str,
        }

    def create_pie_chart(
        self,
        filename: str,
        label_column: str,
        value_column: Optional[str] = None,
        title: Optional[str] = None,
        limit: int = 10,
    ) -> Dict[str, Any]:
        """Create a pie chart visualization.

        Args:
            filename: CSV file name, resolved by ``_load_dataframe``.
            label_column: Column providing the slice labels.
            value_column: Column summed per label to size the slices. When
                omitted, slices are sized by how often each label occurs.
            title: Chart title; defaults to
                ``"Distribution of <label_column>"``.
            limit: Maximum number of slices; the largest ones are kept.

        Returns:
            A dictionary with ``chart_type``, ``label_column``,
            ``value_column``, ``categories`` (slices drawn) and the base64
            ``image``.

        Raises:
            ValueError: If the CSV file is not found.
            KeyError: If a requested column is not in the data.
        """
        df = self._load_dataframe(filename)

        if value_column is None:
            # No value column: count occurrences of each label.
            counts = df[label_column].value_counts().head(limit)
            labels = counts.index.tolist()
            values = counts.values.tolist()
        else:
            # Group by label, sum the values and keep the top categories.
            grouped = df.groupby(label_column)[value_column].sum().reset_index()
            grouped = grouped.nlargest(limit, value_column)
            labels = grouped[label_column].tolist()
            values = grouped[value_column].tolist()

        with self._managed_figure() as (fig, ax):
            ax.pie(
                values,
                labels=labels,
                autopct='%1.1f%%',
                startangle=90,
                shadow=True,
            )
            # Equal aspect ratio ensures that pie is drawn as a circle.
            ax.axis('equal')

            ax.set_title(title or f"Distribution of {label_column}")

            img_str = self._figure_to_base64(fig)

        return {
            "chart_type": "pie",
            "label_column": label_column,
            "value_column": value_column,
            "categories": len(labels),
            "image": img_str,
        }