Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import dgl | |
class WindGNN(nn.Module):
    """Graph network with a shared trunk and velocity/direction output heads.

    Node features are projected to ``hidden_size`` by a linear layer, passed
    through ``num_layers`` ``dgl.nn.GraphConv`` layers (each followed by
    LayerNorm, ReLU and dropout), and then fed to two small MLP heads:

    * ``velocity_head`` -> 1 value per row, squashed by a sigmoid.
    * ``direction_head`` -> 2 values per row, L2-normalised along ``dim=1``.

    Args:
        in_feats: Size of the last dimension of the input feature tensor.
        hidden_size: Width of the trunk; each head uses ``hidden_size // 4``
            as its intermediate width.
        out_feats: Accepted for interface compatibility but not used by this
            class; the output width is always 3 (1 velocity + 2 direction).
        num_layers: Number of graph-convolution layers. With 0, the heads are
            applied directly to the input projection and ``g`` is never used.
        dropout_rate: Dropout probability used in the trunk and in both
            heads. Keyword-only; defaults to the previously hard-coded 0.2.

    Raises:
        ValueError: If ``dropout_rate`` is outside ``[0, 1]`` (raised by
            ``nn.Dropout``).
    """

    def __init__(self, in_feats=5, hidden_size=128, out_feats=3, num_layers=3,
                 *, dropout_rate=0.2):
        super().__init__()
        # NOTE: submodules are registered in the same order as before so that
        # state_dict key order and parameter-initialisation order are stable.
        self.layers = nn.ModuleList()
        self.input_proj = nn.Linear(in_feats, hidden_size)
        for _ in range(num_layers):
            self.layers.append(dgl.nn.GraphConv(hidden_size, hidden_size))
        # Sigmoid bounds the velocity output to the (0, 1) range.
        self.velocity_head = self._build_head(
            hidden_size, 1, dropout_rate, nn.Sigmoid()
        )
        # No final activation here: forward() normalises the 2-vector instead.
        self.direction_head = self._build_head(hidden_size, 2, dropout_rate)
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout_rate)

    @staticmethod
    def _build_head(hidden_size, out_dim, dropout_rate, *final_layers):
        """Build a Linear -> LayerNorm -> ReLU -> Dropout -> Linear head."""
        bottleneck = hidden_size // 4
        return nn.Sequential(
            nn.Linear(hidden_size, bottleneck),
            nn.LayerNorm(bottleneck),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(bottleneck, out_dim),
            *final_layers,
        )

    def forward(self, g, features):
        """Run the network and return concatenated velocity and direction.

        Args:
            g: Graph object passed as the first argument to every
                ``GraphConv`` layer (presumably a ``dgl.DGLGraph`` - confirm
                against the caller).
            features: Input tensor whose last dimension is ``in_feats``. The
                ``dim=1`` normalise/concat below assume a 2-D
                ``(rows, in_feats)`` layout - TODO confirm.

        Returns:
            Tensor formed by concatenating, along ``dim=1``, the 1-column
            velocity output and the 2-column unit-normalised direction output.
        """
        h = self.input_proj(features)
        for i, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
            h_new = self.dropout(nn.functional.relu(norm(conv(g, h))))
            # The first layer replaces the projected input outright; later
            # layers are added residually onto the running representation.
            h = h_new if i == 0 else h + h_new
        velocity = self.velocity_head(h)
        direction = nn.functional.normalize(self.direction_head(h), dim=1)
        return torch.cat([velocity, direction], dim=1)