# ChatCSV / tools/data_tools_bk.py
# Last commit (10ee83d, by Chamin09): Rename tools/data_tools.py to tools/data_tools_bk.py
from typing import Dict, List, Any, Optional, Callable
import pandas as pd
import numpy as np
from llama_index.tools import FunctionTool
from pathlib import Path
class PandasDataTools:
    """Tools for data analysis operations on CSV files.

    Each public analysis method is wrapped as a LlamaIndex ``FunctionTool``
    (see ``_create_tools``) so an agent can call it by name. Problems with the
    tool arguments -- an unknown column, an unsupported operator, a value that
    cannot be compared -- are returned as ``{"error": ...}`` dicts rather than
    raised, so the agent can read the message and retry. A missing CSV file is
    still reported by raising ``ValueError`` from ``_load_dataframe``.
    """

    # Maximum number of rows included in the "results" preview returned by
    # filter_data and sort_data.
    PREVIEW_ROWS = 10

    # Ordering operators accepted by filter_data; the comparison value is
    # converted to float before being applied.
    _ORDERING_OPS: Dict[str, Callable[[pd.Series, float], pd.Series]] = {
        ">": lambda s, v: s > v,
        "<": lambda s, v: s < v,
        ">=": lambda s, v: s >= v,
        "<=": lambda s, v: s <= v,
    }

    # Aggregations accepted by group_and_aggregate, mapped to the pandas
    # GroupBy method that computes them. "count" uses ``size`` so that it
    # counts every row in the group, including rows where the aggregated
    # column is NaN.
    _AGGREGATIONS: Dict[str, str] = {
        "mean": "mean",
        "sum": "sum",
        "min": "min",
        "max": "max",
        "count": "size",
        "median": "median",
    }

    def __init__(self, csv_directory: str):
        """Initialize with the directory containing CSV files.

        Args:
            csv_directory: Directory that every ``filename`` argument is
                resolved against.
        """
        self.csv_directory = csv_directory
        # Loaded DataFrames, keyed by the filename exactly as the caller passed it.
        self.dataframes = {}
        self.tools = self._create_tools()

    def _load_dataframe(self, filename: str) -> pd.DataFrame:
        """Load a CSV file as a DataFrame, caching it by ``filename``.

        ``filename`` may omit the ``.csv`` extension. It must be a relative
        path without ``..`` components. The name usually comes from an LLM
        tool call, and without this check it could read files outside
        ``csv_directory``.

        Raises:
            ValueError: if the name is absolute or contains ``..``, or if no
                matching file exists.
        """
        if filename in self.dataframes:
            return self.dataframes[filename]

        requested = Path(filename)
        # ``Path(dir) / "/abs/path"`` silently discards ``dir``, and ".." walks
        # out of it. ``anchor`` is non-empty for POSIX roots and Windows drives.
        if requested.anchor or ".." in requested.parts:
            raise ValueError(f"Invalid CSV file name: {filename}")

        base = Path(self.csv_directory)
        file_path = base / requested
        if not file_path.is_file() and not filename.endswith(".csv"):
            file_path = base / f"{filename}.csv"
        if not file_path.is_file():
            raise ValueError(f"CSV file not found: {filename}")

        self.dataframes[filename] = pd.read_csv(file_path)
        return self.dataframes[filename]

    @staticmethod
    def _missing_column_error(df: pd.DataFrame, *columns: str) -> Optional[Dict[str, Any]]:
        """Return an error dict for the first of ``columns`` absent from ``df``, else None."""
        for name in columns:
            if name not in df.columns:
                return {
                    "error": f"Column not found: {name}. "
                             f"Available columns: {df.columns.tolist()}"
                }
        return None

    @staticmethod
    def _coerce_like(series: pd.Series, value: Any) -> Any:
        """Convert a string ``value`` to int/float when ``series`` is numeric.

        Tool arguments often arrive as strings (e.g. ``"20"``). Comparing such
        a string to a numeric column with ``==`` would silently match nothing.
        Non-string values, and strings that do not parse as numbers, are
        returned unchanged. Boolean columns are left alone.
        """
        if not isinstance(value, str):
            return value
        if not pd.api.types.is_numeric_dtype(series) or pd.api.types.is_bool_dtype(series):
            return value
        # Try int first so large integer IDs keep full precision.
        for convert in (int, float):
            try:
                return convert(value)
            except ValueError:
                continue
        return value

    def _create_tools(self) -> List["FunctionTool"]:
        """Create LlamaIndex function tools for data operations."""
        tools = [
            FunctionTool.from_defaults(
                name="describe_csv",
                description="Get statistical description of a CSV file",
                fn=self.describe_csv
            ),
            FunctionTool.from_defaults(
                name="filter_data",
                description="Filter CSV data based on conditions",
                fn=self.filter_data
            ),
            FunctionTool.from_defaults(
                name="group_and_aggregate",
                description="Group data and calculate aggregate statistics",
                fn=self.group_and_aggregate
            ),
            FunctionTool.from_defaults(
                name="sort_data",
                description="Sort data by specified columns",
                fn=self.sort_data
            ),
            FunctionTool.from_defaults(
                name="calculate_correlation",
                description="Calculate correlation between columns",
                fn=self.calculate_correlation
            )
        ]
        return tools

    def get_tools(self) -> List["FunctionTool"]:
        """Get all available data tools."""
        return self.tools

    # Tool implementations

    def describe_csv(self, filename: str) -> Dict[str, Any]:
        """Get a statistical description of CSV data.

        Returns:
            ``statistics`` (``DataFrame.describe()`` as a nested dict),
            ``shape`` (rows, columns), ``columns`` (names in order) and
            ``dtypes`` (column name -> dtype string).
        """
        df = self._load_dataframe(filename)
        return {
            "statistics": df.describe().to_dict(),
            "shape": df.shape,
            "columns": df.columns.tolist(),
            "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        }

    def filter_data(self, filename: str, column: str, condition: str, value: Any) -> Dict[str, Any]:
        """Filter data based on condition (==, >, <, >=, <=, !=, contains).

        ``==`` / ``!=`` compare against ``value``. For numeric columns, a
        string ``value`` is first converted to a number. The ordering
        operators convert ``value`` to float. ``contains`` is a literal
        (non-regex), case-sensitive substring test on the string form of each
        cell; missing cells never match. The condition keyword itself is
        case-insensitive and surrounding whitespace is ignored.

        Returns:
            ``result_count`` (number of matching rows), ``results`` (the first
            ``PREVIEW_ROWS`` matches as records) and ``total_count`` (rows in
            the file). On a bad column, condition or value, returns a dict
            with an ``"error"`` key instead.
        """
        df = self._load_dataframe(filename)
        error = self._missing_column_error(df, column)
        if error is not None:
            return error

        series = df[column]
        op = condition.strip().lower()
        try:
            if op == "==":
                mask = series == self._coerce_like(series, value)
            elif op == "!=":
                mask = series != self._coerce_like(series, value)
            elif op in self._ORDERING_OPS:
                mask = self._ORDERING_OPS[op](series, float(value))
            elif op == "contains":
                # regex=False: user text such as "a(" or "." must be matched
                # literally. notna(): astype(str) would turn NaN/None into
                # "nan"/"None", which could otherwise match the search text.
                mask = series.notna() & series.astype(str).str.contains(str(value), regex=False)
            else:
                return {"error": f"Unsupported condition: {condition}"}
        except (TypeError, ValueError) as exc:
            # float() on a non-number, or an ordering comparison between
            # incompatible types (e.g. strings vs float).
            return {"error": f"Could not apply condition '{condition}' to column '{column}': {exc}"}

        filtered = df[mask]
        return {
            "result_count": len(filtered),
            "results": filtered.head(self.PREVIEW_ROWS).to_dict(orient="records"),
            "total_count": len(df),
        }

    def group_and_aggregate(self, filename: str, group_by: str, agg_column: str,
                            agg_function: str = "mean") -> Dict[str, Any]:
        """Group by a column and calculate an aggregate statistic.

        Args:
            group_by: Column whose distinct values form the groups (rows with
                a missing group key are dropped, as in ``DataFrame.groupby``).
            agg_column: Column to aggregate within each group.
            agg_function: One of mean, sum, min, max, count, median
                (case-insensitive). ``count`` counts all rows in the group.

        Returns:
            The echoed parameters plus ``results`` mapping group value ->
            aggregate, or a dict with an ``"error"`` key.
        """
        df = self._load_dataframe(filename)
        key = agg_function.strip().lower()
        if key not in self._AGGREGATIONS:
            return {"error": f"Unsupported aggregation function: {agg_function}"}
        error = self._missing_column_error(df, group_by, agg_column)
        if error is not None:
            return error

        try:
            # Named GroupBy methods instead of numpy callables: identical
            # results, without pandas' FutureWarning about callables in agg().
            grouped = getattr(df.groupby(group_by)[agg_column], self._AGGREGATIONS[key])()
        except (TypeError, ValueError) as exc:
            # e.g. mean/median over a non-numeric column.
            return {"error": f"Could not compute {key} of '{agg_column}': {exc}"}

        return {
            "group_by": group_by,
            "aggregated_column": agg_column,
            "aggregation": key,
            "results": grouped.to_dict(),
        }

    def sort_data(self, filename: str, sort_by: str, ascending: bool = True) -> Dict[str, Any]:
        """Sort data by a column and return the first ``PREVIEW_ROWS`` rows.

        Returns:
            ``sorted_by``, ``ascending`` and ``results`` (records), or a dict
            with an ``"error"`` key for an unknown or unsortable column.
        """
        df = self._load_dataframe(filename)
        error = self._missing_column_error(df, sort_by)
        if error is not None:
            return error

        try:
            sorted_df = df.sort_values(by=sort_by, ascending=ascending)
        except TypeError as exc:
            # Mixed, mutually incomparable types in an object column.
            return {"error": f"Could not sort by '{sort_by}': {exc}"}

        return {
            "sorted_by": sort_by,
            "ascending": ascending,
            "results": sorted_df.head(self.PREVIEW_ROWS).to_dict(orient="records"),
        }

    def calculate_correlation(self, filename: str, column1: str, column2: str) -> Dict[str, Any]:
        """Calculate the Pearson correlation between two columns.

        Returns:
            ``correlation`` as a plain float (NaN when undefined, e.g. for a
            constant column) plus the column names, or a dict with an
            ``"error"`` key for unknown or non-numeric columns.
        """
        df = self._load_dataframe(filename)
        error = self._missing_column_error(df, column1, column2)
        if error is not None:
            return error

        try:
            correlation = df[column1].corr(df[column2])
        except (TypeError, ValueError) as exc:
            return {"error": f"Could not calculate correlation: {str(exc)}"}

        return {
            "correlation": float(correlation),
            "column1": column1,
            "column2": column2,
        }