# Hugging Face Spaces page banner captured with this source (Space status: Sleeping).
# main.py
# ----------------------------------------------------------------------------
# Import necessary libraries
# ----------------------------------------------------------------------------
# Install with: pip install gradio numpy pandas matplotlib scipy transformers torch sentencepiece
# ----------------------------------------------------------------------------
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from scipy.stats import norm | |
from transformers import pipeline | |
import warnings | |
import os | |
# Suppress every Python warning for the whole process so that library notices
# do not clutter the console output.
warnings.filterwarnings("ignore")
# Switch Matplotlib to the non-interactive 'Agg' backend to avoid display
# issues in environments without a GUI.
plt.switch_backend('Agg')
# ----------------------------------------------------------------------------
# Global Variables and Initial Setup
# ----------------------------------------------------------------------------
# Text-generation pipeline used to explain simulation results.
# A small model (flan-t5-small) is chosen so the app stays responsive.
# Loading is best-effort: on any failure the app keeps running and
# `explanation_generator` stays None, which disables explanations.
explanation_generator = None
try:
    explanation_generator = pipeline(
        'text2text-generation',
        model='google/flan-t5-small',
    )
    print("Hugging Face model loaded successfully.")
except Exception as load_error:
    print(f"Could not load Hugging Face model. Explanations will be disabled. Error: {load_error}")
    explanation_generator = None
# Create a sample dataset for demonstration purposes.
# It simulates the uncertain costs (in thousands of $) of the tasks in a project
# and is written to disk so the "Use Example" option can load it like any CSV.
SAMPLE_CSV_PATH = 'sample_project_costs.csv'
sample_project_costs = pd.DataFrame({
    'task_cost_thousands': [12, 15, 10, 13, 18, 9, 22, 14, 16, 11, 17, 20],
})
# index=False keeps the file a single-column CSV, which is what the
# validation in process_input_data requires.
sample_project_costs.to_csv(SAMPLE_CSV_PATH, index=False)
# ----------------------------------------------------------------------------
# Core Logic Functions
# ----------------------------------------------------------------------------
def create_error_plot(message):
    """Create a Matplotlib figure that shows only an error message.

    Args:
        message (str): Text rendered in red at the centre of the figure.

    Returns:
        matplotlib.figure.Figure: A figure without axis ticks containing the message.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates registered until it is explicitly closed, and its
    # "current figure" is process-global state, so a long-running server would
    # accumulate figures and concurrent requests could touch each other's plot.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 5))
    ax = fig.subplots()
    ax.text(0.5, 0.5, message, ha='center', va='center', wrap=True, color='red', fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.tight_layout()
    return fig
def process_input_data(file_obj, example_choice, manual_mean, manual_std):
    """
    Load, validate and summarise the input selected in the UI.

    Sources are tried in this order: File Upload > Example Dataset > Manual Entry.
    File and example data must be a single column of numbers with at least two
    finite values and a non-zero standard deviation.

    Args:
        file_obj: The uploaded file from gr.File, or None. Either an object
            exposing the path as ``.name`` or a plain path string is accepted.
        example_choice (str): The name of the chosen example dataset.
        manual_mean (float): Manually entered mean.
        manual_std (float): Manually entered standard deviation (must be > 0).

    Returns:
        tuple: A tuple containing:
            - A pandas DataFrame with the processed data, or None on failure.
              For manual input this is a one-row frame with 'mean' and 'std' columns.
            - A Matplotlib figure showing the data distribution (or an error/info figure).
            - A string with summary statistics, or None on failure.
            - A string with a validation message.
    """
    # Figures are built directly (not via pyplot) so they are not kept alive in
    # pyplot's global registry and do not depend on its "current figure" state.
    from matplotlib.figure import Figure

    def _failure(plot_message, status_message=None):
        """Return the standard 4-tuple for a failed preparation step."""
        if status_message is None:
            status_message = plot_message
        return None, create_error_plot(plot_message), None, status_message

    data = None
    source_info = ""

    # 1. Prioritize input source
    if file_obj is not None:
        # NOTE(review): gr.File may pass a temp-file wrapper (path in `.name`)
        # or a plain path string depending on Gradio version/config; accept both.
        file_path = getattr(file_obj, "name", file_obj)
        try:
            data = pd.read_csv(file_path)
            source_info = f"from uploaded file: {os.path.basename(file_path)}"
        except Exception as e:
            return _failure(
                f"Error reading file: {e}",
                f"Error reading file: {e}. Please ensure it's a valid CSV.",
            )
    elif example_choice and example_choice == "Project Cost Estimation":
        data = pd.read_csv(SAMPLE_CSV_PATH)
        source_info = "from the 'Project Cost Estimation' example"
    elif manual_mean is not None and manual_std is not None:
        if manual_std <= 0:
            return _failure(
                "Standard Deviation must be positive.",
                "Manual Input Error: Standard Deviation must be positive.",
            )
        stats_text = (f"Source: Manual Input\n"
                      f"Mean: {manual_mean:.2f}\n"
                      f"Standard Deviation: {manual_std:.2f}")
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, 'Manual input:\nNo data to plot.\nSimulation will use\nthe provided Mean/Std.',
                ha='center', va='center', fontsize=12)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.tight_layout()
        # The one-row 'mean'/'std' frame is the marker the simulation step
        # uses to recognise manual parameters.
        manual_df = pd.DataFrame({'mean': [manual_mean], 'std': [manual_std]})
        return manual_df, fig, stats_text, "Manual parameters accepted. Ready to run simulation."

    if data is None:
        return _failure(
            "No data source provided.",
            "No data source provided. Please upload a file, choose an example, or enter parameters.",
        )

    # 2. Validate data structure
    if data.shape[1] != 1 or not pd.api.types.is_numeric_dtype(data.iloc[:, 0]):
        error_msg = (f"Data Error: The data {source_info} is not compatible. "
                     "The app requires a CSV with a single column of numerical data. "
                     f"Detected {data.shape[1]} columns.")
        return _failure(error_msg)

    # 3. Process valid data
    series = data.iloc[:, 0].dropna()
    # With fewer than two values the sample standard deviation is NaN, which
    # would pass the `std == 0` check below and break the simulation later.
    if len(series) < 2:
        return _failure(
            "Data Error: At least two numeric values are required to estimate a standard deviation."
        )
    mean = series.mean()
    std = series.std()
    if not (np.isfinite(mean) and np.isfinite(std)):
        return _failure(
            "Data Error: The data contains non-finite values (inf), cannot compute mean and standard deviation."
        )
    if std == 0:
        return _failure(
            "Data Error: All values are the same. Standard deviation is zero, cannot simulate uncertainty."
        )

    # 4. Generate visualization and stats
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.hist(series, bins='auto', density=True, alpha=0.7, label='Input Data Distribution')
    # Read the limits from this axes object rather than pyplot's current axes.
    xmin, xmax = ax.get_xlim()
    x = np.linspace(xmin, xmax, 100)
    p = norm.pdf(x, mean, std)
    ax.plot(x, p, 'k', linewidth=2, label='Fitted Normal Curve')
    ax.set_title("Distribution of Input Data")
    ax.set_xlabel(series.name)
    ax.set_ylabel("Density")
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    stats_text = (f"Source: {source_info}\n"
                  f"Number of Data Points: {len(series)}\n"
                  f"Mean: {mean:.2f}\n"
                  f"Standard Deviation: {std:.2f}\n"
                  f"Min: {series.min():.2f}\n"
                  f"Max: {series.max():.2f}")
    validation_message = "Data loaded and validated successfully! Ready to run the simulation."
    return data, fig, stats_text, validation_message
def run_monte_carlo_simulation(data, num_simulations, target_value):
    """
    Perform the Monte Carlo simulation based on the processed data.

    Draws `num_simulations` samples from a normal distribution whose mean and
    standard deviation come from `data`, then plots and summarises them.

    Args:
        data (pandas.DataFrame | None): Either a one-row frame with 'mean' and
            'std' columns (manual parameters) or a frame whose first column
            holds the observed values. None means no data was prepared.
        num_simulations (int | float): Number of random draws; converted with int().
        target_value (float | None): Optional threshold; when given, the share
            of outcomes <= this value is reported.

    Returns:
        tuple: (histogram figure, CDF figure, summary text). On invalid input
        both figures are the same error figure and the text starts with
        "Simulation failed".
    """
    # Figures are built directly (not via pyplot) so they are not kept alive in
    # pyplot's global registry and do not depend on its "current figure" state.
    from matplotlib.figure import Figure

    def _failure(error_message):
        """Return the standard 3-tuple for a simulation that cannot run."""
        error_plot = create_error_plot(error_message)
        return error_plot, error_plot, "Simulation failed. See plot for details."

    if data is None:
        return _failure(
            "ERROR: No valid data available.\nPlease go to Step 1 & 2 and click 'Prepare Simulation' first."
        )

    num_simulations = int(num_simulations)
    # The CDF below divides by (n - 1) and the percentiles need samples.
    if num_simulations < 2:
        return _failure("ERROR: Number of simulations must be at least 2.")

    if 'mean' in data.columns and 'std' in data.columns and data.shape[0] == 1:
        # Manual parameters: one row carrying the distribution directly.
        mean = data['mean'].iloc[0]
        std = data['std'].iloc[0]
        data_name = "Value"
    else:
        series = data.iloc[:, 0]
        mean = series.mean()
        std = series.std()
        data_name = series.name

    # A NaN or negative spread would make np.random.normal raise or return NaN.
    if not (np.isfinite(mean) and np.isfinite(std)) or std < 0:
        return _failure(
            "ERROR: Input data does not define a valid mean and standard deviation.\n"
            "Please click 'Prepare Simulation' again with valid data."
        )

    simulation_results = np.random.normal(mean, std, num_simulations)
    sim_mean = np.mean(simulation_results)
    p5 = np.percentile(simulation_results, 5)
    p50 = np.percentile(simulation_results, 50)
    p95 = np.percentile(simulation_results, 95)

    # --- Histogram of simulated outcomes ---
    fig_hist = Figure(figsize=(8, 5))
    ax_hist = fig_hist.subplots()
    ax_hist.hist(simulation_results, bins=50, density=True, alpha=0.8, color='skyblue', edgecolor='black')
    ax_hist.axvline(sim_mean, color='red', linestyle='--', linewidth=2, label=f'Mean: {sim_mean:.2f}')
    ax_hist.axvline(p5, color='green', linestyle=':', linewidth=2, label=f'5th Percentile (P5): {p5:.2f}')
    ax_hist.axvline(p95, color='green', linestyle=':', linewidth=2, label=f'95th Percentile (P95): {p95:.2f}')
    ax_hist.set_title(f'Monte Carlo Simulation Results ({num_simulations:,} Iterations)', fontsize=14)
    ax_hist.set_xlabel(f'Simulated {data_name}')
    ax_hist.set_ylabel('Probability Density')
    ax_hist.legend()
    ax_hist.grid(True, linestyle='--', alpha=0.6)
    fig_hist.tight_layout()

    # --- Empirical cumulative distribution ---
    fig_cdf = Figure(figsize=(8, 5))
    ax_cdf = fig_cdf.subplots()
    sorted_results = np.sort(simulation_results)
    # Rank i of n sorted samples is mapped to i / (n - 1), i.e. 0 .. 1 inclusive.
    yvals = np.arange(len(sorted_results)) / float(len(sorted_results) - 1)
    ax_cdf.plot(sorted_results, yvals, label='CDF')
    ax_cdf.plot(p5, 0.05, 'go', ms=8, label=f'P5: {p5:.2f}')
    ax_cdf.plot(p50, 0.50, 'ro', ms=8, label=f'Median (P50): {p50:.2f}')
    ax_cdf.plot(p95, 0.95, 'go', ms=8, label=f'P95: {p95:.2f}')
    ax_cdf.set_title('Cumulative Distribution Function (CDF)', fontsize=14)
    ax_cdf.set_xlabel(f'Simulated {data_name}')
    ax_cdf.set_ylabel('Cumulative Probability')
    ax_cdf.grid(True, linestyle='--', alpha=0.6)
    ax_cdf.legend()
    fig_cdf.tight_layout()

    results_summary = (
        f"Simulation Summary ({num_simulations:,} iterations):\n"
        "--------------------------------------------------\n"
        f"Mean (Average Outcome): {sim_mean:.2f}\n"
        f"Standard Deviation: {np.std(simulation_results):.2f}\n\n"
        "Percentiles (Confidence Range):\n"
        f" - 5th Percentile (P5): {p5:.2f}\n"
        f" - 50th Percentile (Median): {p50:.2f}\n"
        f" - 95th Percentile (P95): {p95:.2f}\n"
        f"This means there is a 90% probability the outcome will be between {p5:.2f} and {p95:.2f}.\n\n"
    )
    if target_value is not None:
        # Share of simulated outcomes at or below the target, as a percentage.
        prob_achieved = np.mean(simulation_results <= target_value) * 100
        results_summary += (
            "Probability Analysis:\n"
            f" - Probability of outcome being less than or equal to {target_value:.2f}: {prob_achieved:.2f}%\n"
        )
    return fig_hist, fig_cdf, results_summary
def generate_explanation(results_summary):
    """
    Ask the Hugging Face model for a plain-language reading of the results.

    Args:
        results_summary (str): Text summary produced by the simulation step.

    Returns:
        str: The generated explanation, or a message describing why no
        explanation could be produced (model missing, no successful
        simulation, or an error raised by the model call).
    """
    if explanation_generator is None:
        return "LLM model not loaded. Cannot generate explanation."

    # Summaries containing one of these markers come from a failed run.
    failure_markers = ("Please process valid data", "Simulation failed")
    if not results_summary or any(marker in results_summary for marker in failure_markers):
        return "Could not generate explanation. Please run a successful simulation first."

    prompt = (
        "\n"
        "Explain the following Monte Carlo simulation results to a non-technical manager.\n"
        "Focus on what the numbers mean in terms of risk and decision-making. Be concise and clear.\n"
        "Results:\n"
        f"{results_summary}\n"
        "Explanation:\n"
    )
    try:
        outputs = explanation_generator(prompt, max_length=200, num_beams=3, no_repeat_ngram_size=2)
        return outputs[0]['generated_text']
    except Exception as e:
        return f"Error generating explanation: {e}"
# ----------------------------------------------------------------------------
# Gradio UI Layout
# ----------------------------------------------------------------------------
# Build the whole UI and its event wiring inside one Blocks context named `app`.
# NOTE(review): the nesting below is reconstructed from the `with` statements;
# confirm the placement of components that follow a nested block
# (e.g. `prepare_button`) against the deployed layout.
with gr.Blocks(theme=gr.themes.Soft(), title="Monte Carlo Simulation Explorer") as app:
    # Introductory text shown at the top of the page.
    gr.Markdown(
        """
# Welcome to the Monte Carlo Simulation Explorer!
This tool helps you understand and perform a Monte Carlo simulation, a powerful technique for modeling uncertainty.
**How it works:** Instead of guessing a single outcome, you provide a range of possible inputs (or a distribution). The simulation then runs thousands of trials with random values from that input, creating a probability distribution of all possible outcomes.
**Get started:**
1. **Provide Data:** Use one of the methods in the "Data Collection" box below.
2. **Prepare Simulation:** Click the "Prepare Simulation" button to validate and visualize your input.
3. **Run Simulation:** Adjust the settings and click "Run Simulation".
4. **Interpret:** Analyze the resulting plots and get an AI-powered explanation.
"""
    )
    # --- Row 1: Data Input and Preparation ---
    with gr.Row():
        # Left column: the three alternative input methods, one per tab.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 1. Data Collection")
                gr.Markdown("Choose **one** method below.")
                with gr.Tabs():
                    with gr.TabItem("Upload File"):
                        file_input = gr.File(label="Upload a Single-Column CSV File", file_types=[".csv"])
                    with gr.TabItem("Use Example"):
                        example_input = gr.Dropdown(
                            ["Project Cost Estimation"], label="Select an Example Dataset"
                        )
                    with gr.TabItem("Manual Input"):
                        gr.Markdown("Define a normal distribution manually.")
                        manual_mean_input = gr.Number(label="Mean (Average)")
                        manual_std_input = gr.Number(label="Standard Deviation (Spread)")
                # Triggers process_input_data (wired further below).
                prepare_button = gr.Button("Prepare Simulation", variant="secondary")
        # Right column: read-only feedback from the preparation step.
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("### 2. Preparation & Visualization")
                validation_output = gr.Textbox(label="Validation Status", interactive=False, lines=3)
                input_stats_output = gr.Textbox(label="Input Data Statistics", interactive=False, lines=6)
                input_plot_output = gr.Plot(label="Input Data Distribution")
    # --- Row 2: Simulation Controls and Results ---
    with gr.Row():
        with gr.Group():
            gr.Markdown("### 3. Simulation Run & Results")
            with gr.Row():
                # Settings column: number of draws and optional target threshold.
                with gr.Column(scale=1, min_width=250):
                    gr.Markdown("**Simulation Settings**")
                    num_simulations_input = gr.Slider(
                        minimum=1000, maximum=50000, value=10000, step=1000,
                        label="Number of Simulations"
                    )
                    target_value_input = gr.Number(
                        label="Target Value (Optional)",
                        info="Calculate the probability of the result being <= this value."
                    )
                    run_button = gr.Button("Run Simulation", variant="primary")
                # Results column: one tab per output of run_monte_carlo_simulation.
                with gr.Column(scale=3):
                    with gr.Tabs():
                        with gr.TabItem("Results Histogram"):
                            results_plot_output = gr.Plot(label="Simulation Outcome Distribution")
                        with gr.TabItem("Cumulative Probability (CDF)"):
                            cdf_plot_output = gr.Plot(label="Cumulative Distribution Function")
                        with gr.TabItem("Numerical Summary"):
                            results_summary_output = gr.Textbox(label="Detailed Results", interactive=False, lines=12)
    # --- Row 3: AI-Powered Explanation ---
    with gr.Row():
        with gr.Group():
            gr.Markdown("### 4. AI-Powered Explanation")
            explain_button = gr.Button("Explain the Takeaways", variant="secondary")
            explanation_output = gr.Textbox(
                label="Key Takeaways from the LLM",
                interactive=False,
                lines=5,
                placeholder="Click the button above to generate an explanation of the results..."
            )
    # ----------------------------------------------------------------------------
    # Define UI Component Interactions
    # ----------------------------------------------------------------------------
    # Holds the first return value of process_input_data between the
    # "Prepare Simulation" and "Run Simulation" clicks.
    processed_data_state = gr.State()
    # Prepare: the four outputs map, in order, to the 4-tuple returned by
    # process_input_data (data, figure, stats text, validation message).
    prepare_button.click(
        fn=process_input_data,
        inputs=[file_input, example_input, manual_mean_input, manual_std_input],
        outputs=[processed_data_state, input_plot_output, input_stats_output, validation_output]
    )
    # Run: the three outputs map, in order, to the 3-tuple returned by
    # run_monte_carlo_simulation (histogram, CDF, summary text).
    run_button.click(
        fn=run_monte_carlo_simulation,
        inputs=[processed_data_state, num_simulations_input, target_value_input],
        outputs=[results_plot_output, cdf_plot_output, results_summary_output]
    )
    # Explain: feeds the summary textbox content to generate_explanation.
    explain_button.click(
        fn=generate_explanation,
        inputs=[results_summary_output],
        outputs=[explanation_output]
    )
# ----------------------------------------------------------------------------
# Launch the Gradio App
# ----------------------------------------------------------------------------
# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    app.launch(debug=True)