# Dallas_Wind / gnn_inference.py
# jithin14's picture
# Update gnn_inference.py
# 74cd686 verified
import json
import numpy as np
import pandas as pd
import torch
import dgl
import gradio as gr
import plotly.graph_objects as go
import plotly.express as px
from sklearn.neighbors import KDTree
from gnn_model import WindGNN
from utils import NORM_PARAMS, uv_to_velocity_direction, velocity_direction_to_uv, denormalize_predictions
def plot_errors(box_errors):
    """Create a plotly visualization with error annotations.

    Each entry of ``box_errors`` is drawn as the axis-aligned bounding
    rectangle of its points, labelled with its velocity and direction error.

    Args:
        box_errors: Iterable of dicts with the keys ``'coords'`` (an (N, 2)
            array-like whose column 0 is plotted on the x axis and column 1
            on the y axis), ``'velocity_error'`` and ``'direction_error'``
            (numbers formatted with one decimal).

    Returns:
        A ``plotly.graph_objects.Figure``. Boxes without any coordinates are
        skipped, so the figure may contain no traces.
    """
    fig = go.Figure()
    colors = px.colors.qualitative.Set3

    for i, box_error in enumerate(box_errors):
        # Accept lists as well as arrays; column indexing needs an ndarray.
        coords = np.asarray(box_error['coords'])
        if coords.size == 0:
            # min/max of an empty array raise ValueError; an empty box has
            # no extent to draw, so leave it out instead of failing the plot.
            continue

        lon_min, lon_max = np.min(coords[:, 0]), np.max(coords[:, 0])
        lat_min, lat_max = np.min(coords[:, 1]), np.max(coords[:, 1])

        center_lon = np.mean([lon_min, lon_max])
        # The label is anchored below the vertical midpoint of the box
        # (35% of the box height), then nudged up by a fixed 0.02 offset.
        center_lat = np.mean([lat_min, lat_max]) - (lat_max - lat_min) * 0.35

        # Add box boundaries (closed rectangle: the first corner is repeated).
        fig.add_trace(go.Scatter(
            x=[lon_min, lon_max, lon_max, lon_min, lon_min],
            y=[lat_min, lat_min, lat_max, lat_max, lat_min],
            mode='lines',
            # Index by the original position so a box keeps its colour even
            # when earlier boxes were skipped.
            line=dict(color=colors[i % len(colors)])
        ))

        # Add annotations for velocity and direction errors.
        fig.add_annotation(
            x=center_lon,
            y=center_lat + 0.02,
            text=f"v:{box_error['velocity_error']:.1f}<br>d:{box_error['direction_error']:.1f}°",
            showarrow=False,
            font=dict(size=8),
            align='center'
        )

    fig.update_layout(
        showlegend=False,
        title="Error Distribution by Grid Box",
        xaxis_title="Longitude",
        yaxis_title="Latitude"
    )
    return fig
# Number of nearest neighbours every node is linked to in the k-NN graph.
_NUM_NEIGHBORS = 8
# Number of consecutive-row groups ("boxes") shown in the error plot.
_NUM_BOXES = 24
# Checkpoint file; the weights are read from its 'model_state_dict' entry.
_CHECKPOINT_PATH = '2.pth'


def _knn_edges(coords, num_neighbors=_NUM_NEIGHBORS):
    """Return (src, dst) int64 arrays linking each point to its nearest neighbours.

    The neighbour count is clamped to ``len(coords) - 1`` so that inputs with
    fewer than ``num_neighbors + 1`` rows do not make the KD-tree query fail.
    Edges are ordered by source node, then by increasing neighbour rank.
    """
    n_points = len(coords)
    k = min(num_neighbors + 1, n_points)
    _, indices = KDTree(coords).query(coords, k=k)
    src = np.repeat(np.arange(n_points, dtype=np.int64), k - 1)
    # Column 0 is dropped on the assumption that it is the query point itself
    # (distance 0). NOTE(review): with exactly duplicated coordinates the
    # first hit may be the duplicate instead - confirm this is acceptable.
    dst = indices[:, 1:].astype(np.int64).reshape(-1)
    return src, dst


def _angular_error(predicted, actual):
    """Return the absolute angle difference in degrees, folded into [0, 180]."""
    # The modulo keeps the wrap-around correct even when the two angles are
    # more than a full turn apart (e.g. one in [-180, 180], one in [0, 360)).
    diff = np.abs(predicted - actual) % 360
    return np.minimum(diff, 360 - diff)


def _group_errors(coords, velocity_errors, direction_errors, num_boxes=_NUM_BOXES):
    """Split the rows into at most ``num_boxes`` consecutive, non-empty groups.

    Returns a list of dicts with the keys ``'coords'``, ``'velocity_error'``
    and ``'direction_error'`` (the latter two are per-group means).
    """
    n_boxes = min(num_boxes, len(coords))
    if n_boxes == 0:
        return []
    # array_split spreads any remainder over the first groups, so every row
    # belongs to exactly one box and no box is empty.
    chunks = zip(
        np.array_split(coords, n_boxes),
        np.array_split(velocity_errors, n_boxes),
        np.array_split(direction_errors, n_boxes),
    )
    return [
        {
            'coords': box_coords,
            'velocity_error': np.mean(box_velocity),
            'direction_error': np.mean(box_direction),
        }
        for box_coords, box_velocity, box_direction in chunks
    ]


def inference(input_csv):
    """Run inference using the GNN model.

    Reads a CSV with the columns ``x``, ``y``, ``u_30``, ``v_30``, ``u_45``
    and ``v_45``, builds a k-nearest-neighbour graph over the (x, y) points,
    predicts velocity and direction with ``WindGNN`` and compares the
    prediction against the values derived from ``u_45`` / ``v_45``.

    Args:
        input_csv: Either a path string or an object whose ``name`` attribute
            is the path of the uploaded CSV file.

    Returns:
        A tuple ``(figure, error_summary, results_dataframe)``. When anything
        fails the tuple is ``(None, "Error: <message>", None)`` so the UI can
        show the message instead of crashing.
    """
    if input_csv is None:
        return None, "Error: No input file provided.", None

    try:
        # Depending on version/configuration the upload widget hands over a
        # plain path string or a temp-file wrapper exposing `.name`.
        csv_path = input_csv if isinstance(input_csv, str) else input_csv.name
        df = pd.read_csv(csv_path)
        if df.empty:
            raise ValueError("Input CSV contains no rows.")

        # Load the model
        model = WindGNN()
        checkpoint = torch.load(_CHECKPOINT_PATH, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Extract coordinates and features
        coords = df[['x', 'y']].values

        # Convert u_30, v_30 to velocity and direction for input
        velocity_30, direction_30 = uv_to_velocity_direction(df['u_30'].values, df['v_30'].values)
        # Convert true u_45, v_45 to velocity and direction for comparison
        true_velocity, true_direction = uv_to_velocity_direction(df['u_45'].values, df['v_45'].values)

        # Min-max normalize coordinates and input velocity with the shared
        # NORM_PARAMS bounds.
        norm_x = (coords[:, 0] - NORM_PARAMS['x_min']) / (NORM_PARAMS['x_max'] - NORM_PARAMS['x_min'])
        norm_y = (coords[:, 1] - NORM_PARAMS['y_min']) / (NORM_PARAMS['y_max'] - NORM_PARAMS['y_min'])
        norm_velocity = (velocity_30 - NORM_PARAMS['v_min']) / (NORM_PARAMS['v_max'] - NORM_PARAMS['v_min'])

        # Encode the direction (treated as degrees) as sin/cos components.
        direction_rad = np.radians(direction_30)
        features = np.column_stack([
            norm_x,
            norm_y,
            norm_velocity,
            np.sin(direction_rad),
            np.cos(direction_rad),
        ])

        # Create graph; num_nodes is stated explicitly so the node count
        # always matches the feature rows, even if a node had no edges.
        src_nodes, dst_nodes = _knn_edges(coords)
        g = dgl.graph(
            (torch.from_numpy(src_nodes), torch.from_numpy(dst_nodes)),
            num_nodes=len(coords),
        )

        # Make predictions
        with torch.no_grad():
            features_tensor = torch.FloatTensor(features)
            predictions = model(g, features_tensor)
            # Denormalize predictions
            pred_velocity, pred_direction = denormalize_predictions(predictions.numpy())

        # Calculate errors (direction handled across the 0/360 boundary)
        velocity_errors = np.abs(pred_velocity - true_velocity)
        direction_errors = _angular_error(pred_direction, true_direction)

        # Create grid boxes for error visualization
        box_errors = _group_errors(coords, velocity_errors, direction_errors)

        # Create visualization
        fig = plot_errors(box_errors)

        # Calculate overall metrics
        mae_velocity = np.mean(velocity_errors)
        mae_direction = np.mean(direction_errors)
        max_velocity_error = np.max(velocity_errors)
        min_velocity_error = np.min(velocity_errors)
        max_direction_error = np.max(direction_errors)
        min_direction_error = np.min(direction_errors)

        # Prepare detailed error information
        error_info = (
            f"Mean Velocity Error: {mae_velocity:.3f} m/s, "
            f"Mean Direction Error: {mae_direction:.3f}°\n"
            f"Max Velocity Error: {max_velocity_error:.3f} m/s, "
            f"Min Velocity Error: {min_velocity_error:.3f} m/s\n"
            f"Max Direction Error: {max_direction_error:.3f}°, "
            f"Min Direction Error: {min_direction_error:.3f}°"
        )

        # Create results DataFrame
        results_df = pd.DataFrame({
            'x': coords[:, 0],
            'y': coords[:, 1],
            'True Velocity': true_velocity,
            'Predicted Velocity': pred_velocity,
            'True Direction': true_direction,
            'Predicted Direction': pred_direction,
            'Velocity Error': velocity_errors,
            'Direction Error': direction_errors
        })

        return fig, error_info, results_df

    except Exception as e:
        # Top-level UI boundary: report the failure in the text output
        # rather than letting the request crash.
        return None, f"Error: {str(e)}", None
if __name__ == "__main__":
    # Script entry point: expose `inference` through a Gradio UI that takes
    # one uploaded CSV and shows a plot, a text summary and a results table.
    csv_upload = gr.File(label="Input CSV (with u_30, v_30, u_45, v_45 columns)")
    result_views = [
        gr.Plot(label="Error Distribution"),
        gr.Textbox(label="Error Metrics"),
        gr.DataFrame(label="Detailed Results"),
    ]

    interface_options = dict(
        fn=inference,
        inputs=[csv_upload],
        outputs=result_views,
        title="Wind Velocity Component Prediction (GNN)",
        description="Predict wind velocity components (u_45, v_45) from current measurements (u_30, v_30) using Graph Neural Network.",
    )

    # Build the interface and start serving it.
    gr.Interface(**interface_options).launch()