Spaces:
Sleeping
Sleeping
| import json | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import dgl | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from sklearn.neighbors import KDTree | |
| from gnn_model import WindGNN | |
| from utils import NORM_PARAMS, uv_to_velocity_direction, velocity_direction_to_uv, denormalize_predictions | |
def plot_errors(box_errors):
    """Create a Plotly figure outlining each error box and annotating its mean errors.

    Args:
        box_errors: Iterable of dicts with keys:
            ``'coords'`` -- 2-D array-like whose column 0 is plotted as
            longitude (x) and column 1 as latitude (y);
            ``'velocity_error'`` / ``'direction_error'`` -- scalars shown in
            the box label with one decimal place.

    Returns:
        A ``plotly.graph_objects.Figure`` with, for every non-empty box, one
        line trace tracing the points' axis-aligned bounding rectangle and one
        text annotation with the two errors. Boxes with no points are skipped.
    """
    fig = go.Figure()
    colors = px.colors.qualitative.Set3
    for i, box_error in enumerate(box_errors):
        coords = np.asarray(box_error['coords'])
        if coords.size == 0:
            # np.min/np.max raise on zero-size arrays; there is nothing to draw.
            continue
        lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
        lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])
        center_lon = np.mean([lon_min, lon_max])
        # Label sits below the vertical centre: 0.35 of the box height down
        # (then nudged up by 0.02 data units when the annotation is placed).
        center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35

        # Closed rectangle around the box's points; colour cycles through Set3
        # by the box's position in the input.
        fig.add_trace(go.Scatter(
            x=[lon_min, lon_max, lon_max, lon_min, lon_min],
            y=[lat_min, lat_min, lat_max, lat_max, lat_min],
            mode='lines',
            line=dict(color=colors[i % len(colors)])
        ))

        # Mean velocity and direction errors for this box.
        fig.add_annotation(
            x=center_lon,
            y=center_lat + 0.02,
            text=f"v:{box_error['velocity_error']:.1f}<br>d:{box_error['direction_error']:.1f}°",
            showarrow=False,
            font=dict(size=8),
            align='center'
        )

    fig.update_layout(
        showlegend=False,
        title="Error Distribution by Grid Box",
        xaxis_title="Longitude",
        yaxis_title="Latitude"
    )
    return fig
# Columns inference() reads from the uploaded CSV.
_REQUIRED_COLUMNS = ('x', 'y', 'u_30', 'v_30', 'u_45', 'v_45')

# Nearest neighbours (excluding the point itself) linked to each graph node.
_NUM_NEIGHBORS = 8

# Number of contiguous row groups summarised in the error plot.
_NUM_BOXES = 24

# Loaded models keyed by checkpoint path, so the checkpoint is read once per
# process instead of on every request.
_MODEL_CACHE = {}


def _load_model(checkpoint_path='2.pth'):
    """Return a WindGNN in eval mode loaded from *checkpoint_path* (cached per path)."""
    model = _MODEL_CACHE.get(checkpoint_path)
    if model is None:
        model = WindGNN()
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        _MODEL_CACHE[checkpoint_path] = model
    return model


def _build_features(df):
    """Return ``(coords, features)``: raw x/y values and the normalised node feature matrix."""
    coords = df[['x', 'y']].values
    # Input features come from the u_30/v_30 components as velocity + direction.
    velocity_30, direction_30 = uv_to_velocity_direction(df['u_30'].values, df['v_30'].values)
    # Min-max normalise with the ranges in NORM_PARAMS (presumably the ones
    # used during training -- TODO confirm against the training script).
    norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
    norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
    norm_velocity = (velocity_30 - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])
    # Encode direction as sin/cos so 0 and 360 degrees map to the same values.
    direction_rad = np.radians(direction_30)
    features = np.column_stack(
        [norm_x, norm_y, norm_velocity, np.sin(direction_rad), np.cos(direction_rad)]
    )
    return coords, features


def _build_knn_graph(coords, num_neighbors=_NUM_NEIGHBORS):
    """Return a DGL graph with an edge from each point to its nearest neighbours."""
    n_points = len(coords)
    # Query one extra neighbour because the closest hit (distance 0) is normally
    # the query point itself; clamp so inputs with few points do not fail.
    k = min(num_neighbors + 1, n_points)
    _, indices = KDTree(coords).query(coords, k=k)
    # NOTE(review): with duplicate coordinates column 0 may be a different point
    # than the query itself, in which case that point is dropped instead.
    src = np.repeat(np.arange(n_points), k - 1)  # i repeated once per neighbour
    dst = indices[:, 1:].ravel()                 # row-major: same order as the per-i loop
    return dgl.graph(
        (torch.as_tensor(src, dtype=torch.int64), torch.as_tensor(dst, dtype=torch.int64)),
        num_nodes=n_points,  # keep node count equal to feature rows even with no edges
    )


def _box_errors(coords, velocity_errors, direction_errors, num_boxes=_NUM_BOXES):
    """Split rows into up to *num_boxes* contiguous groups and average errors per group.

    Every row is assigned to a group (group sizes differ by at most one);
    empty groups, which occur when there are fewer rows than groups, are skipped.
    """
    boxes = []
    for idx in np.array_split(np.arange(len(coords)), num_boxes):
        if idx.size == 0:
            continue
        boxes.append({
            'coords': coords[idx],
            'velocity_error': np.mean(velocity_errors[idx]),
            'direction_error': np.mean(direction_errors[idx]),
        })
    return boxes


def inference(input_csv):
    """Run the wind GNN on an uploaded CSV and report prediction errors.

    Args:
        input_csv: The Gradio file upload -- an object with a ``.name`` path
            attribute or a plain path string (which one depends on the Gradio
            version / ``gr.File`` type), or ``None`` when nothing was uploaded.
            The CSV must contain the columns listed in ``_REQUIRED_COLUMNS``.

    Returns:
        ``(fig, error_info, results_df)``: the per-box error figure, a text
        summary of mean/max/min errors, and a per-row DataFrame of true vs.
        predicted velocity/direction and their errors. On any failure returns
        ``(None, "Error: <message>", None)`` so the UI shows the message.
    """
    import logging

    if input_csv is None:
        return None, "Error: no input file provided", None
    try:
        csv_path = getattr(input_csv, 'name', input_csv)
        df = pd.read_csv(csv_path)
        missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
        if missing:
            raise ValueError(f"missing required column(s): {', '.join(missing)}")
        if df.empty:
            raise ValueError("input CSV contains no rows")

        model = _load_model()
        coords, features = _build_features(df)
        # Ground truth for comparison, converted from the u_45/v_45 components.
        true_velocity, true_direction = uv_to_velocity_direction(df['u_45'].values, df['v_45'].values)
        graph = _build_knn_graph(coords)

        with torch.no_grad():
            predictions = model(graph, torch.FloatTensor(features))
        pred_velocity, pred_direction = denormalize_predictions(predictions.numpy())

        velocity_errors = np.abs(pred_velocity - true_velocity)
        # Angular difference: wrap into [0, 360) first so the 0/360 fold below
        # never goes negative even if a direction lies outside [0, 360).
        direction_diff = np.abs(pred_direction - true_direction) % 360
        direction_errors = np.minimum(direction_diff, 360 - direction_diff)

        fig = plot_errors(_box_errors(coords, velocity_errors, direction_errors))

        mae_velocity = np.mean(velocity_errors)
        mae_direction = np.mean(direction_errors)
        max_velocity_error = np.max(velocity_errors)
        min_velocity_error = np.min(velocity_errors)
        max_direction_error = np.max(direction_errors)
        min_direction_error = np.min(direction_errors)
        error_info = (
            f"Mean Velocity Error: {mae_velocity:.3f} m/s, "
            f"Mean Direction Error: {mae_direction:.3f}°\n"
            f"Max Velocity Error: {max_velocity_error:.3f} m/s, "
            f"Min Velocity Error: {min_velocity_error:.3f} m/s\n"
            f"Max Direction Error: {max_direction_error:.3f}°, "
            f"Min Direction Error: {min_direction_error:.3f}°"
        )

        results_df = pd.DataFrame({
            'x': coords[:, 0],
            'y': coords[:, 1],
            'True Velocity': true_velocity,
            'Predicted Velocity': pred_velocity,
            'True Direction': true_direction,
            'Predicted Direction': pred_direction,
            'Velocity Error': velocity_errors,
            'Direction Error': direction_errors
        })
        return fig, error_info, results_df
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the server log and show
        # only the message to the user.
        logging.getLogger(__name__).exception("Inference failed")
        return None, f"Error: {str(e)}", None
if __name__ == "__main__":
    # Build the Gradio UI around inference() and start the server; runs only
    # when this file is executed as a script, not when it is imported.
    # The upload label lists every column inference() requires, including x/y.
    iface = gr.Interface(
        fn=inference,
        inputs=[
            gr.File(label="Input CSV (with x, y, u_30, v_30, u_45, v_45 columns)")
        ],
        outputs=[
            gr.Plot(label="Error Distribution"),
            gr.Textbox(label="Error Metrics"),
            gr.DataFrame(label="Detailed Results")
        ],
        title="Wind Velocity Component Prediction (GNN)",
        description="Predict wind velocity components (u_45, v_45) from current measurements (u_30, v_30) using Graph Neural Network."
    )
    iface.launch()