"""Gradio app that renders exploratory plots for columns of an uploaded CSV.

For the selected feature columns and target column the app draws, per feature
group, a scatter pair plot, a histogram pair plot and a regression/scatter
grid, plus one Pearson correlation heatmap over all selected columns.
"""

import io

import chardet
import gradio as gr
import matplotlib

# Figures are only ever rendered into in-memory PNG buffers, so no GUI backend
# is needed. Selecting Agg before pyplot is imported also avoids GUI-toolkit
# problems when plotting happens outside the main thread.
matplotlib.use("Agg")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image

PLOT_DPI = 300
PLOTS_PER_GROUP = 3  # scatter, histogram, regression (in that order)
MAX_GROUPS = 2
ENCODING_SAMPLE_BYTES = 10000


def detect_encoding(file):
    """Guess the text encoding of *file* from its first bytes.

    Args:
        file: Path of the file to inspect.

    Returns:
        The encoding name reported by chardet, or ``'utf-8'`` when detection
        fails, is inconclusive, or reports plain ASCII.
    """
    try:
        with open(file, 'rb') as f:
            raw = f.read(ENCODING_SAMPLE_BYTES)  # Read a chunk of the file
        encoding = chardet.detect(raw)['encoding']
    except Exception as e:
        print(f"Error detecting encoding: {str(e)}")
        return 'utf-8'  # Default to UTF-8 if detection fails
    # Only a sample is inspected: an 'ascii' verdict may be wrong for the rest
    # of the file, and UTF-8 decodes every ASCII file identically.
    if encoding is None or encoding.lower() == 'ascii':
        return 'utf-8'
    return encoding


def _file_path(uploaded):
    """Return the filesystem path of a Gradio upload.

    Accepts either an object exposing the path as ``.name`` or a plain path
    string, since the value handed to callbacks differs between Gradio versions.
    """
    return getattr(uploaded, 'name', uploaded)


def _figure_to_buffer(fig):
    """Render *fig* as PNG into a new ``BytesIO`` rewound to position 0."""
    buf = io.BytesIO()
    # bbox_inches='tight' keeps the suptitle, which is placed above the canvas
    # (y=1.02), from being cropped out of the saved image.
    fig.savefig(buf, format='png', dpi=PLOT_DPI, bbox_inches='tight')
    buf.seek(0)
    return buf


def _safe_plot(description, plot_fn, *args):
    """Call ``plot_fn(*args)``; on any error print it and return ``None``.

    Keeps one failing plot from preventing the remaining plots.
    """
    try:
        return plot_fn(*args)
    except Exception as e:
        print(f"Error in {description}: {str(e)}")
        return None


def _iter_axes(grid):
    """Yield ``(row, col, ax)`` for every axes that exists in a PairGrid.

    With ``corner=True`` the upper-triangle entries of ``grid.axes`` are
    ``None`` and must be skipped.
    """
    for (row, col), ax in np.ndenumerate(grid.axes):
        if ax is not None:
            yield row, col, ax


def _finish_grid(grid, title):
    """Add the title, normalise label sizes and return the grid as a PNG buffer."""
    grid.fig.suptitle(title, y=1.02, fontsize=16)
    # Adjust label size and spacing
    for _, _, ax in _iter_axes(grid):
        ax.tick_params(labelsize=10)
        if ax.get_xlabel():
            ax.set_xlabel(ax.get_xlabel(), fontsize=12)
        if ax.get_ylabel():
            ax.set_ylabel(ax.get_ylabel(), fontsize=12)
    grid.fig.tight_layout()
    return _figure_to_buffer(grid.fig)


def _scatter_pairplot(data, target_column, is_numeric_target, title):
    """Draw a corner scatter pair plot of *data* and return it as a PNG buffer.

    A numeric target colours the points of every off-diagonal panel through a
    viridis colour bar; a non-numeric target is used as the seaborn ``hue``.
    """
    if is_numeric_target:
        grid = sns.pairplot(data, kind='scatter', plot_kws={'alpha': 0.6}, corner=True)
    else:
        grid = sns.pairplot(data, hue=target_column, kind='scatter', corner=True)
    try:
        if is_numeric_target:
            norm = plt.Normalize(data[target_column].min(), data[target_column].max())
            for row, col, ax in _iter_axes(grid):
                x_var, y_var = grid.x_vars[col], grid.y_vars[row]
                if x_var == y_var or not ax.collections:
                    continue
                points = ax.collections[0]
                # Rows with a missing x or y are not drawn, so select the
                # target values of the remaining rows to stay aligned.
                drawn = data[[x_var, y_var]].notna().all(axis=1)
                colors = data.loc[drawn, target_column].to_numpy()
                if len(colors) != len(points.get_offsets()):
                    continue  # cannot align colours with points; leave as is
                points.set_cmap('viridis')
                points.set_norm(norm)
                points.set_array(colors)
                grid.fig.colorbar(points, ax=ax, label=target_column)
        return _finish_grid(grid, title)
    finally:
        plt.close(grid.fig)


def _hist_pairplot(data, target_column, is_numeric_target, title):
    """Draw a corner histogram pair plot of *data* and return it as a PNG buffer.

    With a numeric target the diagonal shows histograms with a KDE overlay and
    the off-diagonal panels show 2-D histograms. The 2-D bins are not recoloured
    by the target: they are binned counts, not one artist per data row. A
    non-numeric target is used as the seaborn ``hue``.
    """
    if is_numeric_target:
        grid = sns.pairplot(
            data,
            kind='hist',
            diag_kind='hist',
            diag_kws={'kde': True},
            plot_kws={'alpha': 0.6},
            corner=True,
        )
    else:
        grid = sns.pairplot(data, kind='hist', hue=target_column, corner=True)
    try:
        return _finish_grid(grid, title)
    finally:
        plt.close(grid.fig)


def _regression_grid(df, feature_names, target_column, is_numeric_target, title):
    """Draw an n-by-n grid over *feature_names* and return it as a PNG buffer.

    Diagonal cells hold a histogram with KDE. Off-diagonal cells hold a scatter
    coloured by the target when it is numeric, otherwise a seaborn regression
    plot with a red fit line.
    """
    n_features = len(feature_names)
    # squeeze=False keeps `axes` two-dimensional even for a single feature.
    fig, axes = plt.subplots(n_features, n_features, figsize=(16, 14), squeeze=False)
    try:
        fig.suptitle(title, y=1.02, fontsize=16)
        for i, feature1 in enumerate(feature_names):
            for j, feature2 in enumerate(feature_names):
                ax = axes[i, j]
                if i == j:
                    sns.histplot(df[feature1], ax=ax, kde=True)
                elif is_numeric_target:
                    scatter = ax.scatter(df[feature1], df[feature2], c=df[target_column],
                                         cmap='viridis', alpha=0.6)
                    fig.colorbar(scatter, ax=ax, label=target_column)
                else:
                    sns.regplot(x=feature1, y=feature2, data=df, ax=ax,
                                scatter_kws={'alpha': 0.6}, line_kws={'color': 'red'})
                ax.set_xlabel(feature1, fontsize=10)
                ax.set_ylabel(feature2, fontsize=10)
                ax.tick_params(labelsize=8)
                ax.set_title(f'{feature1} vs {feature2}', fontsize=12)
        fig.tight_layout()
        return _figure_to_buffer(fig)
    finally:
        plt.close(fig)


def _correlation_heatmap(df, columns):
    """Draw the Pearson correlation heatmap of the numeric *columns* as a PNG buffer."""
    # Restrict to numeric columns: corr() raises on text columns in recent pandas.
    correlation_matrix = df[columns].select_dtypes(include='number').corr()
    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm',
                    square=True, cbar_kws={'shrink': .8}, ax=ax)
        ax.set_title('Pearson Correlation Heatmap', fontsize=16)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)
        plt.setp(ax.get_yticklabels(), fontsize=10)
        fig.tight_layout()
        return _figure_to_buffer(fig)
    finally:
        plt.close(fig)


def create_plots(df, feature_columns, target_column):
    """Render all plots for the selected columns.

    More than three feature columns are split into two halves (groups) that are
    plotted separately; otherwise a single group is used.

    Args:
        df: DataFrame holding the data.
        feature_columns: List of feature column names.
        target_column: Name of the target column.

    Returns:
        ``(plots, num_groups)``. ``plots`` lists, for each group in order, the
        scatter, histogram and regression plot, followed by the correlation
        heatmap. Each entry is a PNG ``BytesIO`` buffer, or ``None`` when that
        plot failed, so positions stay stable. The list is shorter only if
        setup itself failed (e.g. an unknown column name).
    """
    # The target must not double as a feature: a duplicated column label would
    # turn df[target_column] into a DataFrame.
    feature_columns = [col for col in feature_columns if col != target_column]
    # Determine the number of groups based on the number of feature columns.
    # Done before the try block so the value is always defined on return.
    num_groups = MAX_GROUPS if len(feature_columns) > 3 else 1
    plots = []
    try:
        # Check if the target column is numeric
        is_numeric_target = pd.api.types.is_numeric_dtype(df[target_column])

        # Split the features into groups
        if num_groups == 2:
            mid = len(feature_columns) // 2
            feature_groups = [feature_columns[:mid], feature_columns[mid:]]
        else:
            feature_groups = [feature_columns]

        for group, group_features in enumerate(feature_groups, 1):
            # Add target to each feature set
            data = df[group_features + [target_column]]
            plots.append(_safe_plot(
                f"scatter plot for group {group}", _scatter_pairplot,
                data, target_column, is_numeric_target, f'Scatter Plots - Group {group}'))
            plots.append(_safe_plot(
                f"histogram plot for group {group}", _hist_pairplot,
                data, target_column, is_numeric_target, f'Histogram Plots - Group {group}'))
            plots.append(_safe_plot(
                f"regression plot for group {group}", _regression_grid,
                df, group_features, target_column, is_numeric_target,
                f'Regression Plots - Group {group}'))

        plots.append(_safe_plot(
            "correlation heatmap", _correlation_heatmap,
            df, feature_columns + [target_column]))
    except Exception as e:
        print(f"Error in create_plots: {str(e)}")

    return plots, num_groups


def process_csv(csv_file):
    """Populate both column dropdowns with the columns of the uploaded CSV.

    Returns:
        Two ``gr.update`` objects (feature dropdown, target dropdown). They are
        no-op updates when no file is given or the file cannot be read.
    """
    try:
        if csv_file is not None:
            path = _file_path(csv_file)
            df = pd.read_csv(path, encoding=detect_encoding(path))
            columns = df.columns.tolist()
            return gr.update(choices=columns), gr.update(choices=columns)
        return gr.update(), gr.update()
    except Exception as e:
        print(f"Error in process_csv: {str(e)}")
        return gr.update(), gr.update()


def run_analysis(csv_file, feature_columns, target_column):
    """Build all plots and arrange them for the seven image outputs.

    Returns:
        A list of seven PIL images (or ``None`` for an empty slot) in the order
        of the click handler's outputs: scatter group 1, scatter group 2,
        histogram group 1, histogram group 2, regression group 1,
        regression group 2, heatmap.
    """
    n_outputs = PLOTS_PER_GROUP * MAX_GROUPS + 1
    try:
        if csv_file is None or not feature_columns or not target_column:
            return [None] * n_outputs

        path = _file_path(csv_file)
        df = pd.read_csv(path, encoding=detect_encoding(path))

        plot_buffers, num_groups = create_plots(df, feature_columns, target_column)

        # Convert BytesIO objects to PIL Images, keeping None for failed plots
        images = [Image.open(buf) if buf is not None else None for buf in plot_buffers]

        # create_plots lists plots group by group with the heatmap last, while
        # the UI outputs are ordered by plot kind, so re-index explicitly.
        n_group_plots = PLOTS_PER_GROUP * num_groups
        group_images = images[:n_group_plots]
        group_images += [None] * (n_group_plots - len(group_images))
        heatmap_image = images[n_group_plots] if len(images) > n_group_plots else None

        ordered = []
        for kind in range(PLOTS_PER_GROUP):
            for group in range(MAX_GROUPS):
                if group < num_groups:
                    ordered.append(group_images[group * PLOTS_PER_GROUP + kind])
                else:
                    ordered.append(None)
        ordered.append(heatmap_image)
        return ordered
    except Exception as e:
        print(f"Error in run_analysis: {str(e)}")
        return [None] * n_outputs


# Create Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Data Analysis Tool")
    gr.Markdown("Upload a CSV file and select columns to generate plots.")

    with gr.Row():
        csv_file = gr.File(label="Upload CSV file")
        feature_columns = gr.Dropdown(label="Select Feature Columns", multiselect=True)
        target_column = gr.Dropdown(label="Select Target Column")

    csv_file.upload(fn=process_csv, inputs=[csv_file], outputs=[feature_columns, target_column])

    analyze_btn = gr.Button("Analyze")

    with gr.Row():
        plot1 = gr.Image(label="Scatter Plots - Group 1")
        plot4 = gr.Image(label="Scatter Plots - Group 2")

    with gr.Row():
        plot2 = gr.Image(label="Histogram Plots - Group 1")
        plot5 = gr.Image(label="Histogram Plots - Group 2")

    with gr.Row():
        plot3 = gr.Image(label="Regression Plots - Group 1")
        plot6 = gr.Image(label="Regression Plots - Group 2")

    with gr.Row():
        heatmap = gr.Image(label="Pearson Correlation Heatmap")

    analyze_btn.click(fn=run_analysis,
                      inputs=[csv_file, feature_columns, target_column],
                      outputs=[plot1, plot4, plot2, plot5, plot3, plot6, heatmap])

# Launch the app
if __name__ == "__main__":
    iface.launch()