import warnings

import torch
import torch.nn as nn
import dgl
import dgl.nn  # bind the ``dgl.nn`` submodule explicitly rather than relying on ``dgl``'s re-exports


class WindGNN(nn.Module):
    """Graph network predicting a per-node wind velocity and unit direction vector.

    Node features are linearly projected to ``hidden_size`` and passed through
    ``num_layers`` ``GraphConv`` layers. Each layer is followed by LayerNorm,
    ReLU and dropout. The first layer's output replaces the projection; later
    layers are added residually. Two MLP heads then read the final embedding:

    * ``velocity_head``: one value squashed into ``(0, 1)`` by a sigmoid.
    * ``direction_head``: a 2-vector L2-normalised to unit length.

    ``forward`` concatenates them, so its last dimension is always 3.
    """

    # Width of forward()'s output: 1 velocity column + 2 direction columns.
    _OUTPUT_WIDTH = 3

    def __init__(
        self,
        in_feats=5,
        hidden_size=128,
        out_feats=3,
        num_layers=3,
        dropout=0.2,
        allow_zero_in_degree=False,
    ):
        """Build the projection, graph-conv stack and the two prediction heads.

        Args:
            in_feats: size of the last dimension of the input node features.
            hidden_size: width of the graph-conv stack. The heads use a
                ``hidden_size // 4`` bottleneck, so it must be at least 4.
            out_feats: kept only for backward compatibility. The output width
                is fixed at 3, and any other value triggers a ``UserWarning``.
            num_layers: number of ``GraphConv`` layers (0 skips message passing).
            dropout: dropout probability used after every graph layer and
                inside both heads.
            allow_zero_in_degree: forwarded to every ``dgl.nn.GraphConv``.
                Per the DGL docs, with ``False`` the layer raises on graphs
                containing nodes without incoming edges.

        Raises:
            ValueError: if ``hidden_size < 4``.
        """
        super().__init__()
        if hidden_size < 4:
            raise ValueError(
                f"hidden_size must be >= 4 (heads use hidden_size // 4), got {hidden_size}"
            )
        if out_feats != self._OUTPUT_WIDTH:
            warnings.warn(
                f"WindGNN ignores out_feats={out_feats}; output always has "
                f"{self._OUTPUT_WIDTH} columns (velocity + 2-D direction)",
                UserWarning,
                stacklevel=2,
            )

        # Sub-modules are created in the same order and under the same
        # attribute names as before. This keeps seeded initialisation and
        # state_dict keys unchanged.
        self.layers = nn.ModuleList()
        self.input_proj = nn.Linear(in_feats, hidden_size)
        for _ in range(num_layers):
            self.layers.append(
                dgl.nn.GraphConv(
                    hidden_size,
                    hidden_size,
                    allow_zero_in_degree=allow_zero_in_degree,
                )
            )
        self.velocity_head = self._make_head(
            hidden_size, 1, dropout, final_activation=nn.Sigmoid()
        )
        self.direction_head = self._make_head(hidden_size, 2, dropout)
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _make_head(hidden_size, out_dim, dropout, final_activation=None):
        """Build ``Linear -> LayerNorm -> ReLU -> Dropout -> Linear [-> activation]``.

        The layer order (and therefore the ``Sequential`` indices) matches
        the heads previously written out inline.
        """
        bottleneck = hidden_size // 4
        modules = [
            nn.Linear(hidden_size, bottleneck),
            nn.LayerNorm(bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(bottleneck, out_dim),
        ]
        if final_activation is not None:
            modules.append(final_activation)
        return nn.Sequential(*modules)

    def forward(self, g, features):
        """Predict velocity and direction for every node of ``g``.

        Args:
            g: DGL graph passed to each ``GraphConv`` layer.
            features: node feature tensor whose last dimension is ``in_feats``.
                Presumably shaped ``(num_nodes, in_feats)``; the concatenation
                along ``dim=1`` below assumes 2-D input.

        Returns:
            Tensor whose column 0 is the sigmoid velocity and columns 1-2 are
            the unit-length direction vector (``(num_nodes, 3)`` for 2-D input).
        """
        h = self.input_proj(features)
        for i, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
            h_new = conv(g, h)
            h_new = norm(h_new)
            h_new = nn.functional.relu(h_new)
            h_new = self.dropout(h_new)
            # First graph layer replaces the projection; later ones are residual.
            h = h_new if i == 0 else h + h_new

        velocity = self.velocity_head(h)
        # L2-normalise so the 2-D direction has unit length (zero stays zero via eps).
        direction = nn.functional.normalize(self.direction_head(h), dim=1)
        return torch.cat([velocity, direction], dim=1)