# Dallas_Wind / gnn_model.py
# Committed by: jithin14
# Commit message: Added wind prediction models and interface
# Commit: a71391c
import torch
import torch.nn as nn
import dgl
class WindGNN(nn.Module):
    """Graph neural network predicting a wind speed and direction per node.

    Node features are linearly projected to ``hidden_size`` and then passed
    through ``num_layers`` DGL ``GraphConv`` layers. Each conv is followed by
    LayerNorm, ReLU and dropout. The first conv output replaces the projected
    features; every later layer is added residually. The final node
    representation feeds two heads:

    * velocity head: one value squashed into (0, 1) by a sigmoid
      (presumably a normalized wind speed -- confirm against the scaling
      used for the training targets);
    * direction head: a 2-vector L2-normalized to unit length
      (presumably a (cos, sin) or (u, v) direction encoding -- confirm).

    ``forward`` returns them concatenated as ``[velocity, dir_0, dir_1]``.

    Submodule attribute names (``input_proj``, ``layers``, ``layer_norms``,
    ``velocity_head``, ``direction_head``) and the Sequential layer order
    define the ``state_dict`` keys, so they must not be renamed or reordered
    without breaking saved checkpoints.
    """

    def __init__(self, in_feats=5, hidden_size=128, out_feats=3, num_layers=3,
                 *, dropout: float = 0.2):
        """Build the projection, graph-conv stack and prediction heads.

        Args:
            in_feats: Number of input features per node.
            hidden_size: Width of the node representation. The heads use a
                ``hidden_size // 4`` bottleneck, so this must be at least 4.
            out_feats: Unused. The output width is fixed at 3 (1 velocity +
                2 direction). Kept only for call-site compatibility.
            num_layers: Number of ``GraphConv`` layers (0 disables message
                passing entirely).
            dropout: Dropout probability used after every conv layer and
                inside both heads.

        Raises:
            ValueError: If ``hidden_size < 4`` (the head bottleneck would have
                zero width).
        """
        super().__init__()
        if hidden_size < 4:
            raise ValueError(
                f"hidden_size must be >= 4 (got {hidden_size}); the prediction "
                "heads use a hidden_size // 4 bottleneck"
            )
        head_size = hidden_size // 4

        self.layers = nn.ModuleList()
        self.input_proj = nn.Linear(in_feats, hidden_size)
        # NOTE(review): DGL's GraphConv rejects graphs containing
        # zero-in-degree nodes by default; confirm the station graph is built
        # with self-loops (or pass allow_zero_in_degree=True).
        for _ in range(num_layers):
            self.layers.append(dgl.nn.GraphConv(hidden_size, hidden_size))

        # Sequential indices (0: Linear, 1: LayerNorm, 4: Linear) are part of
        # the state_dict keys; keep this exact layer order.
        self.velocity_head = nn.Sequential(
            nn.Linear(hidden_size, head_size),
            nn.LayerNorm(head_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_size, 1),
            nn.Sigmoid(),  # bounds the predicted velocity to (0, 1)
        )
        self.direction_head = nn.Sequential(
            nn.Linear(hidden_size, head_size),
            nn.LayerNorm(head_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_size, 2),
        )
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, features):
        """Predict velocity and unit direction vector for every node.

        Args:
            g: DGL graph passed unchanged to every ``GraphConv`` layer (not
                touched at all when ``num_layers == 0``).
            features: Node feature tensor whose last dimension is
                ``in_feats``; presumably shaped ``(num_nodes, in_feats)``.

        Returns:
            Tensor of shape ``(num_nodes, 3)``. Column 0 is the sigmoid
            velocity; columns 1-2 form a unit-length direction vector.
        """
        h = self.input_proj(features)
        for i, (conv, norm) in enumerate(zip(self.layers, self.layer_norms)):
            h_new = self.dropout(nn.functional.relu(norm(conv(g, h))))
            # First conv output replaces the projection; later layers are
            # residual so deeper stacks keep the earlier representation.
            h = h_new if i == 0 else h + h_new
        velocity = self.velocity_head(h)
        # Normalize so only the direction of the 2-vector carries meaning.
        direction = nn.functional.normalize(self.direction_head(h), dim=1)
        return torch.cat([velocity, direction], dim=1)