import functools
import json
import logging

import numpy as np
import pandas as pd
import torch
import dgl
import gradio as gr
import plotly.graph_objects as go
import plotly.express as px
from sklearn.neighbors import KDTree

from gnn_model import WindGNN
from utils import NORM_PARAMS, uv_to_velocity_direction, velocity_direction_to_uv, denormalize_predictions

logger = logging.getLogger(__name__)

# Trained WindGNN checkpoint loaded by inference().
CHECKPOINT_PATH = '2.pth'
# Columns inference() reads from the uploaded CSV.
REQUIRED_COLUMNS = ('x', 'y', 'u_30', 'v_30', 'u_45', 'v_45')
# Each node gets edges to this many nearest neighbours (itself excluded).
NUM_NEIGHBORS = 8
# Number of row groups summarized in the error plot.
NUM_BOXES = 24


def plot_errors(box_errors):
    """Create a plotly figure outlining each box and annotating its mean errors.

    Args:
        box_errors: Sequence of dicts with keys ``'coords'`` (an (N, 2) array
            of x/y points, N >= 1), ``'velocity_error'`` and
            ``'direction_error'``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bounding rectangle per box
        and a two-line ``v:`` / ``d:`` label inside it.
    """
    fig = go.Figure()
    colors = px.colors.qualitative.Set3

    for i, box_error in enumerate(box_errors):
        coords = box_error['coords']
        lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
        lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])
        center_lon = np.mean([lon_min, lon_max])
        # Shift the label 35% of the box height below the box centre.
        center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35

        # Closed polyline tracing the bounding rectangle of the box's points.
        fig.add_trace(go.Scatter(
            x=[lon_min, lon_max, lon_max, lon_min, lon_min],
            y=[lat_min, lat_min, lat_max, lat_max, lat_min],
            mode='lines',
            line=dict(color=colors[i % len(colors)]),
        ))

        # Plotly annotation text is HTML-like: "<br>" is its line break.
        fig.add_annotation(
            x=center_lon,
            y=center_lat + 0.02,
            text=(
                f"v:{box_error['velocity_error']:.1f}<br>"
                f"d:{box_error['direction_error']:.1f}°"
            ),
            showarrow=False,
            font=dict(size=8),
            align='center',
        )

    fig.update_layout(
        showlegend=False,
        title="Error Distribution by Grid Box",
        xaxis_title="Longitude",
        yaxis_title="Latitude",
    )
    return fig


@functools.lru_cache(maxsize=1)
def _load_model(checkpoint_path=CHECKPOINT_PATH):
    """Load WindGNN weights once (CPU, eval mode) and reuse them across calls.

    The checkpoint must be a mapping containing ``'model_state_dict'``.
    NOTE(review): ``torch.load`` unpickles the file - only load trusted checkpoints.
    Failed loads raise and are not cached, so a later call retries.
    """
    model = WindGNN()
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model


def _build_features(df):
    """Build the model inputs and ground truth from the CSV DataFrame.

    Returns:
        ``(coords, features, true_velocity, true_direction)`` where ``coords``
        is the (N, 2) x/y array and ``features`` is the (N, 5) matrix
        ``[norm_x, norm_y, norm_velocity_30, sin(direction_30), cos(direction_30)]``.
    """
    coords = df[['x', 'y']].values

    # Model input: speed/direction derived from u_30, v_30.
    velocity_30, direction_30 = uv_to_velocity_direction(df['u_30'].values, df['v_30'].values)
    # Ground truth for comparison: speed/direction derived from u_45, v_45.
    true_velocity, true_direction = uv_to_velocity_direction(df['u_45'].values, df['v_45'].values)

    # Min-max normalization with the training-time constants from NORM_PARAMS.
    norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
    norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
    norm_velocity = (velocity_30 - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])

    # Direction is encoded as sin/cos so 0 and 360 degrees map to the same point.
    direction_rad = np.radians(direction_30)
    features = np.column_stack([norm_x, norm_y, norm_velocity, np.sin(direction_rad), np.cos(direction_rad)])
    return coords, features, true_velocity, true_direction


def _build_knn_graph(coords, num_neighbors=NUM_NEIGHBORS):
    """Return a DGL graph with edges from every node to its nearest neighbours.

    Edges run ``i -> neighbour`` (same direction as the original per-node loop).
    The neighbour count is capped at ``len(coords) - 1`` so small inputs work.
    """
    n_nodes = len(coords)
    k = min(num_neighbors + 1, n_nodes)  # +1 because each point's nearest hit is itself
    _, indices = KDTree(coords).query(coords, k=k)

    # Column 0 is assumed to be the query point itself and is dropped.
    # NOTE(review): with duplicate coordinates this may not hold - confirm.
    src = np.repeat(np.arange(n_nodes), k - 1)
    dst = indices[:, 1:].reshape(-1)
    return dgl.graph(
        (torch.as_tensor(src, dtype=torch.int64), torch.as_tensor(dst, dtype=torch.int64)),
        num_nodes=n_nodes,
    )


def _box_errors(coords, velocity_errors, direction_errors, num_boxes=NUM_BOXES):
    """Split rows into up to *num_boxes* contiguous groups and average their errors.

    Groups are consecutive CSV rows. All rows are covered: remainders are spread
    over the first groups, and empty groups (fewer rows than boxes) are skipped.
    NOTE(review): this only forms spatial boxes if the CSV is ordered spatially.
    """
    boxes = []
    for idx in np.array_split(np.arange(len(coords)), num_boxes):
        if idx.size == 0:
            continue
        boxes.append({
            'coords': coords[idx],
            'velocity_error': np.mean(velocity_errors[idx]),
            'direction_error': np.mean(direction_errors[idx]),
        })
    return boxes


def inference(input_csv):
    """Run the GNN on an uploaded CSV and compare predictions with u_45/v_45.

    Args:
        input_csv: The Gradio file value. Older Gradio passes an object with a
            ``.name`` path attribute; newer Gradio passes a path string. Both
            are accepted.

    Returns:
        ``(figure, metrics_text, results_df)`` on success, or
        ``(None, "Error: <message>", None)`` if anything fails.
    """
    try:
        if input_csv is None:
            raise ValueError("No input CSV file was provided.")
        csv_path = getattr(input_csv, 'name', input_csv)
        df = pd.read_csv(csv_path)

        missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
        if missing:
            raise ValueError(f"Input CSV is missing required column(s): {', '.join(missing)}")
        if df.empty:
            raise ValueError("Input CSV contains no rows.")

        model = _load_model()
        coords, features, true_velocity, true_direction = _build_features(df)
        g = _build_knn_graph(coords)

        with torch.no_grad():
            predictions = model(g, torch.FloatTensor(features))
        pred_velocity, pred_direction = denormalize_predictions(predictions.numpy())

        velocity_errors = np.abs(pred_velocity - true_velocity)
        # Angular distance on the circle. Reduce mod 360 first so angles
        # outside [0, 360) cannot produce negative errors.
        direction_errors = np.abs(pred_direction - true_direction) % 360.0
        direction_errors = np.minimum(direction_errors, 360.0 - direction_errors)

        fig = plot_errors(_box_errors(coords, velocity_errors, direction_errors))

        mae_velocity = np.mean(velocity_errors)
        mae_direction = np.mean(direction_errors)
        max_velocity_error = np.max(velocity_errors)
        min_velocity_error = np.min(velocity_errors)
        max_direction_error = np.max(direction_errors)
        min_direction_error = np.min(direction_errors)

        error_info = (
            f"Mean Velocity Error: {mae_velocity:.3f} m/s, "
            f"Mean Direction Error: {mae_direction:.3f}°\n"
            f"Max Velocity Error: {max_velocity_error:.3f} m/s, "
            f"Min Velocity Error: {min_velocity_error:.3f} m/s\n"
            f"Max Direction Error: {max_direction_error:.3f}°, "
            f"Min Direction Error: {min_direction_error:.3f}°"
        )

        results_df = pd.DataFrame({
            'x': coords[:, 0],
            'y': coords[:, 1],
            'True Velocity': true_velocity,
            'Predicted Velocity': pred_velocity,
            'True Direction': true_direction,
            'Predicted Direction': pred_direction,
            'Velocity Error': velocity_errors,
            'Direction Error': direction_errors,
        })

        return fig, error_info, results_df
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs, show the message in the UI.
        logger.exception("Inference failed")
        return None, f"Error: {str(e)}", None


if __name__ == "__main__":
    # Create Gradio interface
    iface = gr.Interface(
        fn=inference,
        inputs=[
            gr.File(label="Input CSV (with u_30, v_30, u_45, v_45 columns)")
        ],
        outputs=[
            gr.Plot(label="Error Distribution"),
            gr.Textbox(label="Error Metrics"),
            gr.DataFrame(label="Detailed Results")
        ],
        title="Wind Velocity Component Prediction (GNN)",
        description="Predict wind velocity components (u_45, v_45) from current measurements (u_30, v_30) using Graph Neural Network."
    )
    iface.launch()