#!/usr/bin/env python3
"""
Visualization script using PandasAI.

This script creates a sample dataframe and uses PandasAI to generate and save
visualizations based on user queries.

Usage:
    python visualize.py "Create a bar chart of sales by region"

If no query is given on the command line, a built-in default query is used.

Configuration:
    The PandasAI API key is read from the PANDAS_KEY environment variable.
    load_dotenv() is called first, so the key may also be supplied through
    a .env file.

Requirements:
    - pandas
    - pandasai
    - matplotlib
    - python-dotenv
"""

import os
import sys

import pandas as pd
import matplotlib.pyplot as plt
import pandasai as pai
from dotenv import load_dotenv

# Query used when none is supplied on the command line.
DEFAULT_QUERY = "Plot a bar chart of sales by region"

# File the generated figure is written to (relative to the working directory).
OUTPUT_FILE = "visualization_output.png"


def create_sample_dataframe():
    """Create a sample dataframe with sales data.

    Returns:
        A ``pai.DataFrame`` with eight rows and the columns ``Region``,
        ``Product``, ``Sales``, ``Quarter`` and ``Year``.
    """
    # All five columns are required: with only 'Year' present, a query such
    # as "sales by region" has no data to work with.
    data = {
        'Region': ['North', 'South', 'East', 'West',
                   'North', 'South', 'East', 'West'],
        'Product': ['Widget', 'Widget', 'Widget', 'Widget',
                    'Gadget', 'Gadget', 'Gadget', 'Gadget'],
        'Sales': [150, 200, 120, 180, 90, 110, 95, 130],
        'Quarter': ['Q1', 'Q1', 'Q1', 'Q1', 'Q2', 'Q2', 'Q2', 'Q2'],
        'Year': [2023, 2023, 2023, 2023, 2023, 2023, 2023, 2023],
    }
    return pai.DataFrame(data)


def visualize_data(df, query):
    """
    Generate visualization based on user query using PandasAI.

    Args:
        df: PandasAI dataframe to query (e.g. the one returned by
            create_sample_dataframe)
        query: User query string describing the desired visualization

    Returns:
        Path to the saved visualization file, or None if the API key is
        missing or any step raised an exception
    """
    try:
        # Pick up variables from a .env file before reading the key.
        load_dotenv()
        api_key = os.environ.get("PANDAS_KEY")
        if not api_key:
            # Report the real cause instead of surfacing a bare KeyError text.
            print("Error generating visualization: the PANDAS_KEY "
                  "environment variable is not set")
            return None
        pai.api_key.set(api_key)

        # Announce the work before the (potentially slow) chat call.
        print(f"Generating visualization for query: '{query}'")
        df.chat(query)

        # NOTE(review): this saves whatever is the current pyplot figure once
        # chat() returns and ignores chat()'s return value. Presumably
        # PandasAI leaves the generated chart open as the current figure --
        # confirm; otherwise the saved file will not contain the chart.
        plt.savefig(OUTPUT_FILE)
        plt.close()

        print(f"Visualization saved to {OUTPUT_FILE}")
        return OUTPUT_FILE
    except Exception as e:
        # Top-level boundary for this script: report and signal failure with
        # None so that main() can print its own status line.
        print(f"Error generating visualization: {e}")
        return None


def main():
    """Main function to run the visualization script.

    Uses the first command line argument as the query when one is given and
    falls back to DEFAULT_QUERY otherwise.
    """
    query = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_QUERY

    # Create sample dataframe
    df = create_sample_dataframe()
    print("Sample DataFrame created:")
    print(df.head())

    # Generate and save visualization
    output_file = visualize_data(df, query)

    if output_file:
        print(f"Visualization process completed. Output saved to: {output_file}")
    else:
        print("Visualization process failed.")


if __name__ == "__main__":
    main()