# Hiring_agent/agentpro/tools/data_tool.py
import os
import json
import sqlite3
import tempfile
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from openai import OpenAI

from .base import LLMTool

class DataAnalysisTool(LLMTool):
name: str = "Data Analysis Tool"
description: str = "A tool that can analyze data files (CSV, Excel, etc.) and provide insights. It can generate statistics, visualizations, and exploratory data analysis."
arg: str = "Either a file path or a JSON object with parameters for analysis. If providing a path, supply the full path to the data file. If providing parameters, use the format: {'file_path': 'path/to/file', 'analysis_type': 'basic|correlation|visualization', 'columns': ['col1', 'col2'], 'target': 'target_column'}"
    # Path of the most recently loaded file and the dataframe parsed from it
    _current_file: Optional[str] = None
    _df: Optional[pd.DataFrame] = None
def load_data(self, file_path: str) -> str:
"""Load data from the specified file path."""
try:
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.csv':
self._df = pd.read_csv(file_path)
elif file_ext in ['.xlsx', '.xls']:
self._df = pd.read_excel(file_path)
elif file_ext == '.json':
self._df = pd.read_json(file_path)
elif file_ext == '.parquet':
self._df = pd.read_parquet(file_path)
            elif file_ext == '.sql':
                # For .sql files we expect a SQLite database; load its first table
                conn = sqlite3.connect(file_path)
                tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type='table'", conn)
                if tables.empty:
                    conn.close()
                    return f"No tables found in SQLite database: {file_path}"
                self._df = pd.read_sql(f"SELECT * FROM {tables['name'].iloc[0]}", conn)
                conn.close()
else:
return f"Unsupported file format: {file_ext}. Supported formats: .csv, .xlsx, .xls, .json, .parquet, .sql"
self._current_file = file_path
return f"Successfully loaded data from {file_path}. Shape: {self._df.shape}. Columns: {', '.join(self._df.columns.tolist())}"
except Exception as e:
return f"Error loading data: {str(e)}"
    def generate_basic_stats(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
"""Generate basic statistics for the dataframe or specified columns."""
if self._df is None:
return "No data loaded. Please load data first."
try:
if columns:
# Filter to only include columns that exist in the dataframe
valid_columns = [col for col in columns if col in self._df.columns]
if not valid_columns:
return f"None of the specified columns {columns} exist in the dataframe."
df_subset = self._df[valid_columns]
else:
df_subset = self._df
numeric_stats = df_subset.describe().to_dict()
null_counts = df_subset.isnull().sum().to_dict()
categorical_columns = df_subset.select_dtypes(include=['object', 'category']).columns
unique_counts = {col: df_subset[col].nunique() for col in categorical_columns}
stats = {
"shape": self._df.shape,
"columns": self._df.columns.tolist(),
"numeric_stats": numeric_stats,
"null_counts": null_counts,
"unique_counts": unique_counts
}
return stats
except Exception as e:
return f"Error generating basic statistics: {str(e)}"
    def generate_correlation_analysis(self, columns: Optional[List[str]] = None) -> Union[Dict, str]:
"""Generate correlation analysis for numeric columns."""
if self._df is None:
return "No data loaded. Please load data first."
try:
numeric_df = self._df.select_dtypes(include=[np.number])
if columns:
# Filter to only include numeric columns that were specified
valid_columns = [col for col in columns if col in numeric_df.columns]
if not valid_columns:
return f"None of the specified columns {columns} are numeric or exist in the dataframe."
numeric_df = numeric_df[valid_columns]
if numeric_df.empty:
return "No numeric columns found in the dataset for correlation analysis."
corr_matrix = numeric_df.corr().to_dict()
corr_df = numeric_df.corr().abs()
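            # np.triu with k=1 keeps only the strict upper triangle, so each
            # column pair is examined once; pairs with |r| > 0.7 are flagged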
upper_tri = corr_df.where(np.triu(np.ones(corr_df.shape), k=1).astype(bool))
high_corr = [(col1, col2, upper_tri.loc[col1, col2])
for col1 in upper_tri.index
for col2 in upper_tri.columns
if upper_tri.loc[col1, col2] > 0.7]
high_corr.sort(key=lambda x: x[2], reverse=True)
return {"correlation_matrix": corr_matrix, "high_correlations": high_corr}
except Exception as e:
return f"Error generating correlation analysis: {str(e)}"
def generate_visualization(self, viz_type: str, columns: Optional[List[str]] = None, target: Optional[str] = None) -> str:
"""Generate visualization based on the specified type and columns."""
if self._df is None:
return "No data loaded. Please load data first."
try:
            # Create a temporary file for the rendered figure; delete=False keeps
            # it on disk after the handle closes so the path can be returned
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                output_path = tmp.name
plt.figure(figsize=(10, 6))
# Handle different visualization types
if viz_type == 'histogram':
if not columns or len(columns) == 0:
# If no columns specified, use all numeric columns
numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
if not numeric_cols:
return "No numeric columns found for histogram."
# Limit to 4 columns for readability
columns = numeric_cols[:4]
# Filter to valid columns
valid_columns = [col for col in columns if col in self._df.columns]
if not valid_columns:
return f"None of the specified columns {columns} exist in the dataframe."
for col in valid_columns:
if pd.api.types.is_numeric_dtype(self._df[col]):
plt.hist(self._df[col].dropna(), alpha=0.5, label=col)
plt.legend()
plt.title(f"Histogram of {', '.join(valid_columns)}")
plt.tight_layout()
elif viz_type == 'scatter':
if not columns or len(columns) < 2:
return "Scatter plot requires at least two columns."
# Check if columns exist
if columns[0] not in self._df.columns or columns[1] not in self._df.columns:
return f"One or more of the specified columns {columns[:2]} do not exist in the dataframe."
                # Create scatter plot, coloring by target if one is provided
                x_col, y_col = columns[0], columns[1]
                if target and target in self._df.columns:
                    if pd.api.types.is_numeric_dtype(self._df[target]):
                        scatter = plt.scatter(self._df[x_col], self._df[y_col],
                                              c=self._df[target], alpha=0.5)
                        plt.colorbar(scatter, label=target)
                    else:
                        # For categorical targets, draw one scatter per category
                        for category in self._df[target].unique():
                            mask = self._df[target] == category
                            plt.scatter(self._df.loc[mask, x_col], self._df.loc[mask, y_col],
                                        alpha=0.5, label=str(category))
                        plt.legend()
                else:
                    plt.scatter(self._df[x_col], self._df[y_col], alpha=0.5)
                plt.xlabel(x_col)
                plt.ylabel(y_col)
                plt.title(f"Scatter Plot: {x_col} vs {y_col}")
                plt.tight_layout()
elif viz_type == 'correlation':
# Generate correlation heatmap
numeric_df = self._df.select_dtypes(include=[np.number])
if columns:
# Filter to valid numeric columns
valid_columns = [col for col in columns if col in numeric_df.columns]
if not valid_columns:
return f"None of the specified columns {columns} are numeric or exist in the dataframe."
numeric_df = numeric_df[valid_columns]
if numeric_df.empty:
return "No numeric columns found for correlation heatmap."
sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', linewidths=0.5)
plt.title("Correlation Heatmap")
plt.tight_layout()
elif viz_type == 'boxplot':
if not columns or len(columns) == 0:
# If no columns specified, use all numeric columns
numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
if not numeric_cols:
return "No numeric columns found for boxplot."
# Limit to 5 columns for readability
columns = numeric_cols[:5]
# Filter to valid columns
valid_columns = [col for col in columns if col in self._df.columns]
if not valid_columns:
return f"None of the specified columns {columns} exist in the dataframe."
# Create boxplot
self._df[valid_columns].boxplot()
plt.title("Boxplot of Selected Columns")
plt.xticks(rotation=45)
plt.tight_layout()
elif viz_type == 'pairplot':
# Create a pair plot for multiple columns
if not columns or len(columns) < 2:
# Use first 4 numeric columns if not specified
numeric_cols = self._df.select_dtypes(include=[np.number]).columns.tolist()
if len(numeric_cols) < 2:
return "Not enough numeric columns for a pairplot."
columns = numeric_cols[:min(4, len(numeric_cols))]
# Filter to valid columns
valid_columns = [col for col in columns if col in self._df.columns]
if len(valid_columns) < 2:
return f"Not enough valid columns in {columns} for a pairplot."
                # seaborn's pairplot creates its own figure; close the one opened
                # above and reuse the same output_path when saving below
                plt.close()
if target and target in self._df.columns:
g = sns.pairplot(self._df[valid_columns + [target]], hue=target, height=2.5)
else:
g = sns.pairplot(self._df[valid_columns], height=2.5)
plt.suptitle("Pair Plot of Selected Features", y=1.02)
plt.tight_layout()
else:
return f"Unsupported visualization type: {viz_type}. Supported types: histogram, scatter, correlation, boxplot, pairplot"
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.close()
return f"Visualization saved to: {output_path}"
except Exception as e:
return f"Error generating visualization: {str(e)}"
def generate_data_insights(self) -> str:
"""Generate AI-powered insights about the data."""
if self._df is None:
return "No data loaded. Please load data first."
try:
# Get a sample and info about the data to send to the LLM
df_sample = self._df.head(5).to_string()
            df_info = {
                "shape": list(self._df.shape),
                "columns": self._df.columns.tolist(),
                "dtypes": {col: str(self._df[col].dtype) for col in self._df.columns},
                # Cast numpy integers to plain ints so json.dumps below does not fail
                "missing_values": {col: int(n) for col, n in self._df.isnull().sum().items()},
                "numeric_stats": self._df.describe().to_dict() if not self._df.select_dtypes(include=[np.number]).empty else {},
            }
prompt = f"""
Analyze this dataset and provide key insights.
Dataset Sample:
{df_sample}
Dataset Info:
{json.dumps(df_info, indent=2)}
Your task:
1. Identify the dataset type and potential use cases
2. Summarize the basic characteristics (rows, columns, data types)
3. Highlight key statistics and distributions
4. Point out missing data patterns if any
5. Suggest potential relationships or correlations worth exploring
6. Recommend next steps for deeper analysis
7. Note any data quality issues or anomalies
Provide a comprehensive but concise analysis with actionable insights.
"""
            openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
            model_name = os.environ.get("MODEL_NAME", "gpt-4")  # Default to gpt-4 if MODEL_NAME is not set
            messages = [
                {"role": "system", "content": "You are a data science expert specializing in exploratory data analysis and deriving insights from datasets."},
                {"role": "user", "content": prompt}
            ]
            try:
                if openrouter_api_key:
                    print(f"Using OpenRouter with model: {model_name} for data insights")
                    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=openrouter_api_key)
                    response = client.chat.completions.create(model=model_name, messages=messages, max_tokens=3000)
                else:  # Fall back to the default OpenAI client
                    print("OpenRouter API key not found, using default OpenAI client with gpt-4")
                    response = self.client.chat.completions.create(model="gpt-4", messages=messages, max_tokens=3000)
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error with OpenRouter: {e}")
                print("Falling back to default OpenAI client with gpt-4")
                try:
                    response = self.client.chat.completions.create(model="gpt-4", messages=messages, max_tokens=3000)
                    return response.choices[0].message.content
                except Exception as e2:
                    return f"Error generating data insights with fallback model: {str(e2)}"
except Exception as e:
return f"Error analyzing data for insights: {str(e)}"
def run(self, prompt: Union[str, Dict]) -> str:
"""Run the data analysis tool."""
print(f"Calling Data Analysis Tool with prompt: {prompt}")
        # If prompt is a string, try to parse it as JSON or treat it as a file path
        try:
if isinstance(prompt, str):
try:
params = json.loads(prompt)
except json.JSONDecodeError: # Treat as file path
return self.load_data(prompt)
else:
params = prompt
# Handle different parameter options
if 'file_path' in params:
file_path = params['file_path']
# Load the data first
load_result = self.load_data(file_path)
if "Successfully" not in load_result:
return load_result
# If no analysis type is specified, generate insights
if 'analysis_type' not in params:
return self.generate_data_insights()
analysis_type = params['analysis_type'].lower()
columns = params.get('columns', None)
target = params.get('target', None)
            if analysis_type == 'basic':
                stats = self.generate_basic_stats(columns)
                # default=str handles numpy scalars that json cannot serialize natively
                return json.dumps(stats, indent=2, default=str)
            elif analysis_type == 'correlation':
                corr_analysis = self.generate_correlation_analysis(columns)
                return json.dumps(corr_analysis, indent=2, default=str)
elif analysis_type == 'visualization':
viz_type = params.get('viz_type', 'histogram')
return self.generate_visualization(viz_type, columns, target)
elif analysis_type == 'insights':
return self.generate_data_insights()
else:
return f"Unsupported analysis type: {analysis_type}. Supported types: basic, correlation, visualization, insights"
except Exception as e:
return f"Error executing data analysis: {str(e)}"