from copy import deepcopy

import torch
from torch import nn
import torch.nn.functional as F

from spiga.models.gnn.layers import MLP
class GAT(nn.Module):
    """Graph attention block with a residual connection.

    Runs multi-head self-attention (via ``MessagePassing``) over the input
    features and adds the resulting message to the (possibly re-projected)
    input. Returns both the fused features and the attention probabilities.
    """

    def __init__(self, input_dim: int, output_dim: int, num_heads=4):
        super().__init__()
        heads = num_heads
        self.reshape = None
        if input_dim != output_dim:
            # The attention splits input_dim channels across heads, so pick
            # the largest head count <= num_heads that divides input_dim.
            for heads in range(num_heads, 0, -1):
                if input_dim % heads == 0:
                    break
            # Projection so the residual sum matches the message width.
            self.reshape = MLP([input_dim, output_dim])
        self.attention = MessagePassing(input_dim, heads, out_dim=output_dim)

    def forward(self, features):
        """Return ``(features + message, attention_prob)``."""
        message, prob = self.attention(features)
        if self.reshape is not None:
            features = self.reshape(features)
        return features + message, prob
class MessagePassing(nn.Module):
    """One message-passing step: multi-head self-attention plus MLP fusion.

    The attention message is concatenated with the input features along the
    channel dimension (dim=1) and fused by an MLP down to ``out_dim`` channels.

    Args:
        feature_dim: number of input channels; must be divisible by
            ``num_heads`` (asserted inside ``Attention``).
        num_heads: number of attention heads.
        out_dim: output channels of the fusion MLP. Defaults to
            ``feature_dim`` when ``None`` (previously ``None`` was passed
            straight into the MLP layer spec, breaking construction when the
            argument was omitted).
    """

    def __init__(self, feature_dim: int, num_heads: int, out_dim=None):
        super().__init__()
        if out_dim is None:
            # Bug fix: keep the input width instead of forwarding None.
            out_dim = feature_dim
        self.attn = Attention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim * 2, feature_dim * 2, out_dim])

    def forward(self, features):
        """Return ``(fused_message, attention_prob)`` for self-attention."""
        message, prob = self.attn(features, features, features)
        return self.mlp(torch.cat([features, message], dim=1)), prob
class Attention(nn.Module):
    """Multi-head scaled dot-product attention over 1-D feature maps.

    Query/key/value tensors of shape ``(batch, feature_dim, num_points)`` are
    each projected by a 1x1 ``Conv1d``, attended per head, and merged back to
    ``feature_dim`` channels. Returns ``(merged, prob)`` where ``prob`` holds
    the softmax attention weights.
    """

    def __init__(self, num_heads: int, feature_dim: int):
        super().__init__()
        assert feature_dim % num_heads == 0
        self.dim = feature_dim // num_heads  # channels per head
        self.num_heads = num_heads
        self.merge = nn.Conv1d(feature_dim, feature_dim, kernel_size=1)
        # Three independent 1x1 projections (query, key, value), initialized
        # as copies of the merge layer's construction.
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value):
        """Project inputs, attend per head, and merge heads back together."""
        batch = query.size(0)
        projected = []
        for layer, tensor in zip(self.proj, (query, key, value)):
            projected.append(
                layer(tensor).view(batch, self.dim, self.num_heads, -1))
        query, key, value = projected
        x, prob = self.attention(query, key, value)
        merged = self.merge(
            x.contiguous().view(batch, self.dim * self.num_heads, -1))
        return merged, prob

    def attention(self, query, key, value):
        """Scaled dot-product attention per head (h) over points (n, m)."""
        scale = query.shape[1] ** 0.5
        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / scale
        prob = F.softmax(scores, dim=-1)
        return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob