Spaces:
Running
Running
import graphviz | |
import json | |
from tempfile import NamedTemporaryFile | |
import os | |
from graph_generator_utils import add_nodes_and_edges | |
def generate_wbs_diagram(json_input: str, output_format: str) -> str: | |
""" | |
Generates a Work Breakdown Structure (WBS) Diagram from JSON input. | |
Args: | |
json_input (str): A JSON string describing the WBS structure. | |
It must follow the Expected JSON Format Example below. | |
output_format (str): The output format for the generated diagram. | |
Supported formats: "png" or "svg" | |
Expected JSON Format Example: | |
{ | |
"project_title": "AI Model Development Project", | |
"phases": [ | |
{ | |
"id": "phase_prep", | |
"label": "Preparation", | |
"tasks": [ | |
{ | |
"id": "task_1_1_vision", | |
"label": "Identify Vision", | |
"subtasks": [ | |
{ | |
"id": "subtask_1_1_1_design_staff", | |
"label": "Design & Staffing", | |
"sub_subtasks": [ | |
{ | |
"id": "ss_task_1_1_1_1_env_setup", | |
"label": "Environment Setup", | |
"sub_sub_subtasks": [ | |
{ | |
"id": "sss_task_1_1_1_1_1_lib_install", | |
"label": "Install Libraries", | |
"final_level_tasks": [ | |
{"id": "ft_1_1_1_1_1_1_data_access", "label": "Grant Data Access"} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
}, | |
{ | |
"id": "phase_plan", | |
"label": "Planning", | |
"tasks": [ | |
{ | |
"id": "task_2_1_cost_analysis", | |
"label": "Cost Analysis", | |
"subtasks": [ | |
{ | |
"id": "subtask_2_1_1_benefit_analysis", | |
"label": "Benefit Analysis", | |
"sub_subtasks": [ | |
{ | |
"id": "ss_task_2_1_1_1_risk_assess", | |
"label": "AI Risk Assessment", | |
"sub_sub_subtasks": [ | |
{ | |
"id": "sss_task_2_1_1_1_1_model_selection", | |
"label": "Model Selection", | |
"final_level_tasks": [ | |
{"id": "ft_2_1_1_1_1_1_data_strategy", "label": "Data Strategy"} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
}, | |
{ | |
"id": "phase_dev", | |
"label": "Development", | |
"tasks": [ | |
{ | |
"id": "task_3_1_change_mgmt", | |
"label": "Data Preprocessing", | |
"subtasks": [ | |
{ | |
"id": "subtask_3_1_1_implementation", | |
"label": "Feature Engineering", | |
"sub_subtasks": [ | |
{ | |
"id": "ss_task_3_1_1_1_beta_testing", | |
"label": "Model Training", | |
"sub_sub_subtasks": [ | |
{ | |
"id": "sss_task_3_1_1_1_1_other_task", | |
"label": "Model Evaluation", | |
"final_level_tasks": [ | |
{"id": "ft_3_1_1_1_1_1_hyperparam_tune", "label": "Hyperparameter Tuning"} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
} | |
] | |
} | |
Returns: | |
str: The filepath to the generated image file. | |
""" | |
try: | |
if not json_input.strip(): | |
return "Error: Empty input" | |
data = json.loads(json_input) | |
if 'project_title' not in data or 'phases' not in data: | |
raise ValueError("Missing required fields: project_title or phases") | |
dot = graphviz.Digraph( | |
name='WBSDiagram', | |
format='png', | |
graph_attr={ | |
'rankdir': 'TB', | |
'splines': 'ortho', | |
'bgcolor': 'white', | |
'pad': '0.5', | |
'ranksep': '0.6', | |
'nodesep': '0.5' | |
} | |
) | |
base_color = '#BEBEBE' | |
dot.node( | |
'project_root', | |
data['project_title'], | |
shape='box', | |
style='filled,rounded', | |
fillcolor=base_color, | |
fontcolor='black', | |
fontsize='18' | |
) | |
def get_gradient_color(depth, base_hex_color, lightening_factor=0.06): | |
if not isinstance(base_hex_color, str) or not base_hex_color.startswith('#') or len(base_hex_color) != 7: | |
base_hex_color = '#BEBEBE' | |
base_r = int(base_hex_color[1:3], 16) | |
base_g = int(base_hex_color[3:5], 16) | |
base_b = int(base_hex_color[5:7], 16) | |
current_r = base_r + int((255 - base_r) * depth * lightening_factor) | |
current_g = base_g + int((255 - base_g) * depth * lightening_factor) | |
current_b = base_b + int((255 - base_b) * depth * lightening_factor) | |
return f'#{min(255, current_r):02x}{min(255, current_g):02x}{min(255, current_b):02x}' | |
def get_font_color_for_background(depth, base_hex_color, lightening_factor=0.06): | |
return 'black' | |
def _add_wbs_nodes_recursive(parent_id, current_level_tasks, current_depth): | |
for task_data in current_level_tasks: | |
task_id = task_data.get('id') | |
task_label = task_data.get('label') | |
if not all([task_id, task_label]): | |
raise ValueError(f"Invalid task data at depth {current_depth}: {task_data}") | |
node_fill_color = get_gradient_color(current_depth, base_color) | |
node_font_color = get_font_color_for_background(current_depth, base_color) | |
font_size = max(9, 14 - (current_depth * 2)) | |
dot.node( | |
task_id, | |
task_label, | |
shape='box', | |
style='filled,rounded', | |
fillcolor=node_fill_color, | |
fontcolor=node_font_color, | |
fontsize=str(font_size) | |
) | |
dot.edge(parent_id, task_id, color='#4a4a4a', arrowhead='none') | |
next_level_keys = ['tasks', 'subtasks', 'sub_subtasks', 'sub_sub_subtasks', 'final_level_tasks'] | |
for key_idx, key in enumerate(next_level_keys): | |
if key in task_data and isinstance(task_data[key], list): | |
_add_wbs_nodes_recursive(task_id, task_data[key], current_depth + 1) | |
break | |
phase_depth = 1 | |
for phase in data['phases']: | |
phase_id = phase.get('id') | |
phase_label = phase.get('label') | |
if not all([phase_id, phase_label]): | |
raise ValueError(f"Invalid phase data: {phase}") | |
phase_fill_color = get_gradient_color(phase_depth, base_color) | |
phase_font_color = get_font_color_for_background(phase_depth, base_color) | |
font_size_phase = max(9, 14 - (phase_depth * 2)) | |
dot.node( | |
phase_id, | |
phase_label, | |
shape='box', | |
style='filled,rounded', | |
fillcolor=phase_fill_color, | |
fontcolor=phase_font_color, | |
fontsize=str(font_size_phase) | |
) | |
dot.edge('project_root', phase_id, color='#4a4a4a', arrowhead='none') | |
if 'tasks' in phase and isinstance(phase['tasks'], list): | |
_add_wbs_nodes_recursive(phase_id, phase['tasks'], phase_depth + 1) | |
with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp: | |
dot.render(tmp.name, format=output_format, cleanup=True) | |
return f"{tmp.name}.{output_format}" | |
except json.JSONDecodeError: | |
return "Error: Invalid JSON format" | |
except Exception as e: | |
return f"Error: {str(e)}" |