Spaces:
Runtime error
Runtime error
import argparse
import json
import math
import os
import uuid
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm
def create_graph_visualization(json_path: str, output_dir: str, base_name: str, save_plot: bool = True) -> dict:
    """Build a NetworkX graph from a JSON file and optionally save a plot of it.

    The JSON document is read for a ``nodes`` list (items with ``id``, ``x``,
    ``y`` and optional ``type``/``label``) and an ``edges`` list (items with
    ``start_point``, ``end_point`` and optional ``type``/``weight``). Nodes
    are drawn at their own ``x``/``y`` coordinates, not at a computed layout.

    Args:
        json_path: Path of the JSON file to load.
        output_dir: Directory the PNG is written to; created if it is missing.
        base_name: Stem of the output file name. A trailing ``_aggregated``
            is stripped before use.
        save_plot: When true, render the graph and save it as
            ``<base_name>_graph_visualization.png`` inside ``output_dir``.

    Returns:
        On success, a dict with ``success=True``, ``graph`` (the
        ``nx.Graph``), ``stats`` (valid/invalid node and edge counts) and,
        when ``save_plot`` is true, ``image_path``. On any failure, a dict
        with ``success=False`` and ``error`` (the exception message); no
        exception is propagated to the caller.
    """
    try:
        # Remove '_aggregated' suffix if present
        if base_name.endswith('_aggregated'):
            base_name = base_name[:-len('_aggregated')]

        print("\nLoading JSON data...")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        G = nx.Graph()
        pos = {}
        valid_nodes = []
        invalid_nodes = []

        # First pass - collect valid nodes
        print("\nValidating nodes...")
        for node in tqdm(data.get('nodes', []), desc="Validating"):
            try:
                node_id = str(node.get('id', ''))
                # A missing coordinate yields None, which float() rejects with
                # TypeError. This keeps 0 usable as a real coordinate instead
                # of doubling as the "missing" marker.
                x = float(node.get('x'))
                y = float(node.get('y'))
            except (AttributeError, TypeError, ValueError):
                invalid_nodes.append(node)
                continue

            # NaN/inf cannot be placed on the canvas, so treat them as invalid.
            if node_id and math.isfinite(x) and math.isfinite(y):
                valid_nodes.append(node)
                pos[node_id] = (x, y)
            else:
                invalid_nodes.append(node)

        print(f"\nFound {len(valid_nodes)} valid nodes and {len(invalid_nodes)} invalid nodes")

        print("\nAdding valid nodes...")
        for node in tqdm(valid_nodes, desc="Nodes"):
            G.add_node(
                str(node.get('id', '')),
                type=node.get('type', ''),
                label=node.get('label', ''),
                x=float(node.get('x')),
                y=float(node.get('y')),
            )

        # Add valid edges (only between valid nodes)
        print("\nAdding valid edges...")
        valid_edges = []
        invalid_edges = []
        for edge in tqdm(data.get('edges', []), desc="Edges"):
            try:
                start_id = str(edge.get('start_point', ''))
                end_id = str(edge.get('end_point', ''))
                if start_id in pos and end_id in pos:
                    G.add_edge(
                        start_id,
                        end_id,
                        type=edge.get('type', ''),
                        weight=edge.get('weight', 1.0),
                    )
                    valid_edges.append(edge)
                else:
                    invalid_edges.append(edge)
            except (AttributeError, TypeError, ValueError):
                # Malformed edge record (e.g. not a mapping): count and skip.
                invalid_edges.append(edge)

        print(f"\nFound {len(valid_edges)} valid edges and {len(invalid_edges)} invalid edges")

        # Both success paths report the same keys; 'image_path' is added
        # only when a plot is actually written.
        result = {
            'success': True,
            'graph': G,
            'stats': {
                'valid_nodes': len(valid_nodes),
                'invalid_nodes': len(invalid_nodes),
                'valid_edges': len(valid_edges),
                'invalid_edges': len(invalid_edges),
            },
        }

        if save_plot:
            print("\nGenerating visualization...")
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            image_path = os.path.join(output_dir, f"{base_name}_graph_visualization.png")

            fig = plt.figure(figsize=(20, 20))
            try:
                print("Drawing graph elements...")
                with tqdm(total=3, desc="Drawing") as pbar:
                    nx.draw_networkx_nodes(G, pos,
                                           node_color='lightblue',
                                           node_size=100)
                    pbar.update(1)

                    nx.draw_networkx_edges(G, pos)
                    pbar.update(1)

                    plt.savefig(image_path, bbox_inches='tight', dpi=300)
                    pbar.update(1)
            finally:
                # Release the figure even when drawing or saving fails.
                plt.close(fig)

            print(f"\nVisualization saved to: {image_path}")
            result['image_path'] = image_path

        return result

    except Exception as e:
        # Boundary handler: callers inspect 'success' instead of catching.
        print(f"\nError creating graph: {str(e)}")
        return {
            'success': False,
            'error': str(e)
        }
if __name__ == "__main__":
    # Test the graph visualization independently.

    # Set up argument parser
    parser = argparse.ArgumentParser(description='Create and visualize graph from aggregated JSON')
    parser.add_argument('--json_path', type=str, default="results/002_page_1_aggregated.json",
                        help='Path to aggregated JSON file')
    parser.add_argument('--output_dir', type=str, default="results",
                        help='Directory to save outputs')
    parser.add_argument('--show', action='store_true',
                        help='Show the plot interactively')
    args = parser.parse_args()

    # Verify input file exists
    if not os.path.exists(args.json_path):
        print(f"Error: Could not find input file {args.json_path}")
        # SystemExit works without relying on the site-provided exit() helper.
        raise SystemExit(1)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Get base name from input file and remove '_aggregated' suffix
    base_name = Path(args.json_path).stem
    if base_name.endswith('_aggregated'):
        base_name = base_name[:-len('_aggregated')]

    print("\nProcessing:")
    print(f"Input: {args.json_path}")
    print(f"Output: {args.output_dir}/{base_name}_graph_visualization.png")

    try:
        # Create visualization
        result = create_graph_visualization(
            json_path=args.json_path,
            output_dir=args.output_dir,
            base_name=base_name,
            save_plot=True
        )
    except Exception as e:
        print(f"\nError during visualization: {str(e)}")
        raise

    if not result['success']:
        print(f"\nError: {result['error']}")
        # Report the failure to the shell instead of exiting with status 0.
        raise SystemExit(1)

    print(f"\nSuccess! Graph visualization saved to: {result['image_path']}")
    if args.show:
        # create_graph_visualization closes its figure after saving, so a
        # bare plt.show() has nothing to display; show the saved PNG instead.
        plt.figure(figsize=(10, 10))
        plt.imshow(plt.imread(result['image_path']))
        plt.axis('off')
        plt.show()