Spaces:
Runtime error
Runtime error
File size: 11,352 Bytes
6b59850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 |
import math
import torch
import torch.nn as nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
from torch.nn import functional as F
from torch import Tensor
import utils
from diffusion import diffusion_utils
from models.layers import Xtoy, Etoy, masked_softmax
class XEyTransformerLayer(nn.Module):
    """Transformer layer that jointly updates node, edge and global features.

    The three streams (X, E, y) first go through a shared attention block,
    each followed by dropout, a residual connection and a LayerNorm. Every
    stream then passes through its own two-layer ReLU feed-forward network,
    again wrapped in dropout + residual + LayerNorm.

    dx: node feature dimension
    de: edge feature dimension
    dy: global feature dimension
    n_head: number of heads in the attention block
    dim_ffX, dim_ffE, dim_ffy: hidden widths of the three feed-forward networks
    dropout: dropout probability, 0 to disable
    layer_norm_eps: eps value used by every LayerNorm
    device, dtype: forwarded to the attention block and to every Linear and
        LayerNorm built here
    """

    def __init__(self, dx: int, de: int, dy: int, n_head: int, dim_ffX: int = 2048,
                 dim_ffE: int = 128, dim_ffy: int = 2048, dropout: float = 0.1,
                 layer_norm_eps: float = 1e-5, device=None, dtype=None) -> None:
        super().__init__()
        factory = {'device': device, 'dtype': dtype}

        # NOTE: sub-modules are created in a fixed order (attention, X, E, y)
        # so that parameter initialisation consumes the RNG deterministically.
        self.self_attn = NodeEdgeBlock(dx, de, dy, n_head, **factory)

        # Node stream
        self.linX1 = Linear(dx, dim_ffX, **factory)
        self.linX2 = Linear(dim_ffX, dx, **factory)
        self.normX1 = LayerNorm(dx, eps=layer_norm_eps, **factory)
        self.normX2 = LayerNorm(dx, eps=layer_norm_eps, **factory)
        self.dropoutX1 = Dropout(dropout)
        self.dropoutX2 = Dropout(dropout)
        self.dropoutX3 = Dropout(dropout)

        # Edge stream
        self.linE1 = Linear(de, dim_ffE, **factory)
        self.linE2 = Linear(dim_ffE, de, **factory)
        self.normE1 = LayerNorm(de, eps=layer_norm_eps, **factory)
        self.normE2 = LayerNorm(de, eps=layer_norm_eps, **factory)
        self.dropoutE1 = Dropout(dropout)
        self.dropoutE2 = Dropout(dropout)
        self.dropoutE3 = Dropout(dropout)

        # Global stream
        self.lin_y1 = Linear(dy, dim_ffy, **factory)
        self.lin_y2 = Linear(dim_ffy, dy, **factory)
        self.norm_y1 = LayerNorm(dy, eps=layer_norm_eps, **factory)
        self.norm_y2 = LayerNorm(dy, eps=layer_norm_eps, **factory)
        self.dropout_y1 = Dropout(dropout)
        self.dropout_y2 = Dropout(dropout)
        self.dropout_y3 = Dropout(dropout)

        self.activation = F.relu

    @staticmethod
    def _add_and_norm(residual, update, drop, norm):
        """Return norm(residual + drop(update))."""
        return norm(residual + drop(update))

    def _feed_forward(self, hidden, lin_in, drop_mid, lin_out):
        """Apply lin_in -> activation -> drop_mid -> lin_out to `hidden`."""
        return lin_out(drop_mid(self.activation(lin_in(hidden))))

    def forward(self, X: Tensor, E: Tensor, y, node_mask: Tensor):
        """Run one layer over the node, edge and global features.

        X: (bs, n, d)
        E: (bs, n, n, d)
        y: (bs, dy)
        node_mask: (bs, n), handed unchanged to the attention block
        Output: X, E, y with the same shapes as the inputs.
        """
        attn_X, attn_E, attn_y = self.self_attn(X, E, y, node_mask=node_mask)

        # Attention sub-layer: dropout, residual, LayerNorm (order X, E, y).
        X = self._add_and_norm(X, attn_X, self.dropoutX1, self.normX1)
        E = self._add_and_norm(E, attn_E, self.dropoutE1, self.normE1)
        y = self._add_and_norm(y, attn_y, self.dropout_y1, self.norm_y1)

        # Feed-forward sub-layer, one stream at a time (order X, E, y).
        ffn_X = self._feed_forward(X, self.linX1, self.dropoutX2, self.linX2)
        X = self._add_and_norm(X, ffn_X, self.dropoutX3, self.normX2)

        ffn_E = self._feed_forward(E, self.linE1, self.dropoutE2, self.linE2)
        E = self._add_and_norm(E, ffn_E, self.dropoutE3, self.normE2)

        ffn_y = self._feed_forward(y, self.lin_y1, self.dropout_y2, self.lin_y2)
        y = self._add_and_norm(y, ffn_y, self.dropout_y3, self.norm_y2)

        return X, E, y
class NodeEdgeBlock(nn.Module):
    """Self-attention block that also updates the edge and global features.

    Per-head, per-channel attention scores are built from node queries and
    keys, modulated (FiLM-style: scale and shift) by the edge features, and
    then used twice: projected to produce the new edge features, and
    normalised over the key axis to aggregate node values. The global
    feature y modulates both outputs and is itself updated from pooled node
    and edge information.

    dx: node feature dimension (must be divisible by n_head)
    de: edge feature dimension
    dy: global feature dimension
    n_head: number of attention heads
    kwargs: optional `device` / `dtype`; when given, the whole block
        (including the pooling sub-modules) is moved accordingly. Other
        keys are ignored.
    """

    def __init__(self, dx, de, dy, n_head, **kwargs):
        super().__init__()
        assert dx % n_head == 0, f"dx: {dx} -- nhead: {n_head}"
        self.dx = dx
        self.de = de
        self.dy = dy
        self.df = dx // n_head      # channels per head
        self.n_head = n_head

        # Attention
        self.q = Linear(dx, dx)
        self.k = Linear(dx, dx)
        self.v = Linear(dx, dx)

        # FiLM E to X
        self.e_add = Linear(de, dx)
        self.e_mul = Linear(de, dx)

        # FiLM y to E
        # Warning: output width is dx and not de, because the modulation is
        # applied to the (bs, n, n, dx) scores before the e_out projection.
        self.y_e_mul = Linear(dy, dx)
        self.y_e_add = Linear(dy, dx)

        # FiLM y to X
        self.y_x_mul = Linear(dy, dx)
        self.y_x_add = Linear(dy, dx)

        # Process y
        self.y_y = Linear(dy, dy)
        self.x_y = Xtoy(dx, dy)
        self.e_y = Etoy(de, dy)

        # Output layers
        self.x_out = Linear(dx, dx)
        self.e_out = Linear(dx, de)
        self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy))

        # Honour the device/dtype requested by the caller instead of silently
        # dropping them. Moving the finished module (rather than passing the
        # kwargs to each Linear) also covers Xtoy / Etoy, whose constructor
        # signatures are not visible from this file.
        placement = {key: kwargs[key] for key in ('device', 'dtype')
                     if kwargs.get(key) is not None}
        if placement:
            self.to(**placement)

    def forward(self, X, E, y, node_mask):
        """
        :param X: bs, n, dx node features
        :param E: bs, n, n, de edge features
        :param y: bs, dy global features
        :param node_mask: bs, n
        :return: newX (bs, n, dx), newE (bs, n, n, de), new_y (bs, dy)
        """
        n = X.size(1)
        x_mask = node_mask.unsqueeze(-1)        # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)           # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)           # bs, 1, n, 1

        # 1. Map X to keys and queries
        Q = self.q(X) * x_mask                  # (bs, n, dx)
        K = self.k(X) * x_mask                  # (bs, n, dx)
        # Project-local sanity check; presumably verifies that masked
        # positions are zero -- confirm in diffusion_utils.
        diffusion_utils.assert_correctly_masked(Q, x_mask)

        # 2. Reshape to (bs, n, n_head, df) with dx = n_head * df
        Q = Q.reshape((Q.size(0), Q.size(1), self.n_head, self.df))
        K = K.reshape((K.size(0), K.size(1), self.n_head, self.df))

        Q = Q.unsqueeze(2)                      # (bs, n, 1, n_head, df)
        K = K.unsqueeze(1)                      # (bs, 1, n, n_head, df)

        # Compute unnormalized attentions. Y is (bs, n, n, n_head, df):
        # axis 1 indexes the query node, axis 2 the key node. The product is
        # element-wise, so a score is kept for every channel of every head.
        Y = Q * K
        Y = Y / math.sqrt(Y.size(-1))
        diffusion_utils.assert_correctly_masked(Y, (e_mask1 * e_mask2).unsqueeze(-1))

        E1 = self.e_mul(E) * e_mask1 * e_mask2  # bs, n, n, dx
        E1 = E1.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))

        E2 = self.e_add(E) * e_mask1 * e_mask2  # bs, n, n, dx
        E2 = E2.reshape((E.size(0), E.size(1), E.size(2), self.n_head, self.df))

        # Incorporate edge features to the self attention scores.
        Y = Y * (E1 + 1) + E2                   # (bs, n, n, n_head, df)

        # Incorporate y to E
        newE = Y.flatten(start_dim=3)           # bs, n, n, dx
        ye1 = self.y_e_add(y).unsqueeze(1).unsqueeze(1)  # bs, 1, 1, dx
        ye2 = self.y_e_mul(y).unsqueeze(1).unsqueeze(1)  # bs, 1, 1, dx
        newE = ye1 + (ye2 + 1) * newE

        # Output E
        newE = self.e_out(newE) * e_mask1 * e_mask2      # bs, n, n, de
        diffusion_utils.assert_correctly_masked(newE, e_mask1 * e_mask2)

        # Compute attentions. attn is still (bs, n, n, n_head, df).
        # The mask marks valid key nodes (axis 2), broadcast over queries
        # and heads; masked_softmax is project-local, presumably a softmax
        # that ignores masked-out entries -- confirm in models.layers.
        softmax_mask = e_mask2.expand(-1, n, -1, self.n_head)   # bs, n, n, n_head
        attn = masked_softmax(Y, softmax_mask, dim=2)

        V = self.v(X) * x_mask                  # bs, n, dx
        V = V.reshape((V.size(0), V.size(1), self.n_head, self.df))
        V = V.unsqueeze(1)                      # (bs, 1, n, n_head, df)

        # Compute weighted values: sum over the key axis
        weighted_V = attn * V
        weighted_V = weighted_V.sum(dim=2)      # bs, n, n_head, df

        # Send output to input dim
        weighted_V = weighted_V.flatten(start_dim=2)     # bs, n, dx

        # Incorporate y to X
        yx1 = self.y_x_add(y).unsqueeze(1)      # bs, 1, dx
        yx2 = self.y_x_mul(y).unsqueeze(1)      # bs, 1, dx
        newX = yx1 + (yx2 + 1) * weighted_V

        # Output X
        newX = self.x_out(newX) * x_mask
        diffusion_utils.assert_correctly_masked(newX, x_mask)

        # Process y based on X and E (the *input* X and E, not the updated
        # ones). x_y / e_y are project-local pooling modules, presumably
        # returning (bs, dy) since they are summed with y.
        y = self.y_y(y)
        e_y = self.e_y(E)
        x_y = self.x_y(X)
        new_y = y + x_y + e_y
        new_y = self.y_out(new_y)               # bs, dy

        return newX, newE, new_y
class GraphTransformer(nn.Module):
    """Stack of XEyTransformerLayer blocks between input and output MLPs.

    n_layers : int -- number of transformer layers
    input_dims : dict -- input feature sizes under keys 'X', 'E', 'y'
    cond_dims : int -- extra channels appended to the X and E inputs
        (added to input_dims['X'] and input_dims['E'] for the first Linear)
    hidden_mlp_dims : dict -- hidden widths of the in/out MLPs, keys 'X', 'E', 'y'
    hidden_dims : dict -- transformer sizes, keys 'dx', 'de', 'dy', 'n_head',
        'dim_ffX', 'dim_ffE'
    output_dims : dict -- output feature sizes under keys 'X', 'E', 'y'
    act_fn_in : activation module used in the input MLPs (default: nn.ReLU())
    act_fn_out : activation module used in the output MLPs (default: nn.ReLU())
    """

    def __init__(self, n_layers: int, input_dims: dict, cond_dims: int, hidden_mlp_dims: dict, hidden_dims: dict,
                 output_dims: dict, act_fn_in: "nn.Module | None" = None,
                 act_fn_out: "nn.Module | None" = None):
        super().__init__()
        # The activations used to be "annotated" with nn.ReLU() instances,
        # which provides no default; make ReLU an actual default while still
        # accepting any module passed positionally or by keyword.
        if act_fn_in is None:
            act_fn_in = nn.ReLU()
        if act_fn_out is None:
            act_fn_out = nn.ReLU()

        self.n_layers = n_layers
        self.out_dim_X = output_dims['X']
        self.out_dim_E = output_dims['E']
        self.out_dim_y = output_dims['y']

        self.mlp_in_X = nn.Sequential(nn.Linear(input_dims['X'] + cond_dims, hidden_mlp_dims['X']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['X'], hidden_dims['dx']), act_fn_in)

        self.mlp_in_E = nn.Sequential(nn.Linear(input_dims['E'] + cond_dims, hidden_mlp_dims['E']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['E'], hidden_dims['de']), act_fn_in)

        self.mlp_in_y = nn.Sequential(nn.Linear(input_dims['y'], hidden_mlp_dims['y']), act_fn_in,
                                      nn.Linear(hidden_mlp_dims['y'], hidden_dims['dy']), act_fn_in)

        # NOTE: dim_ffy is not forwarded, so every layer uses the
        # XEyTransformerLayer default for the y feed-forward width.
        self.tf_layers = nn.ModuleList([XEyTransformerLayer(dx=hidden_dims['dx'],
                                                            de=hidden_dims['de'],
                                                            dy=hidden_dims['dy'],
                                                            n_head=hidden_dims['n_head'],
                                                            dim_ffX=hidden_dims['dim_ffX'],
                                                            dim_ffE=hidden_dims['dim_ffE'])
                                        for _ in range(n_layers)])

        self.mlp_out_X = nn.Sequential(nn.Linear(hidden_dims['dx'], hidden_mlp_dims['X']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['X'], output_dims['X']))

        self.mlp_out_E = nn.Sequential(nn.Linear(hidden_dims['de'], hidden_mlp_dims['E']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['E'], output_dims['E']))

        self.mlp_out_y = nn.Sequential(nn.Linear(hidden_dims['dy'], hidden_mlp_dims['y']), act_fn_out,
                                       nn.Linear(hidden_mlp_dims['y'], output_dims['y']))

    def forward(self, X, E, y, node_mask):
        """Run the full network.

        X: (bs, n, input_dims['X'] + cond_dims)
        E: (bs, n, n, input_dims['E'] + cond_dims)
        y: (bs, input_dims['y'])
        node_mask: (bs, n)
        Returns the utils.PlaceHolder built from the output X, E, y after its
        .mask(node_mask) call (project-local; presumably zeroes padded
        nodes/edges -- confirm in utils).
        """
        bs, n = X.shape[0], X.shape[1]

        # Boolean mask that is False on the diagonal (self-loops), built
        # directly on E's device to avoid a CPU allocation + transfer.
        diag_mask = ~torch.eye(n, dtype=torch.bool, device=E.device)
        diag_mask = diag_mask.unsqueeze(0).unsqueeze(-1).expand(bs, -1, -1, -1)

        # Leading channels of the raw inputs, used as skip connections to
        # the outputs.
        X_to_out = X[..., :self.out_dim_X]
        E_to_out = E[..., :self.out_dim_E]
        y_to_out = y[..., :self.out_dim_y]

        # Symmetrise the embedded edges over the two node axes.
        new_E = self.mlp_in_E(E)
        new_E = (new_E + new_E.transpose(1, 2)) / 2
        after_in = utils.PlaceHolder(X=self.mlp_in_X(X), E=new_E, y=self.mlp_in_y(y)).mask(node_mask)
        X, E, y = after_in.X, after_in.E, after_in.y

        for layer in self.tf_layers:
            X, E, y = layer(X, E, y, node_mask)

        X = self.mlp_out_X(X)
        E = self.mlp_out_E(E)
        y = self.mlp_out_y(y)

        X = (X + X_to_out)
        E = (E + E_to_out) * diag_mask      # zero out the diagonal
        y = y + y_to_out

        # Symmetrise the output edges as well.
        E = 1/2 * (E + torch.transpose(E, 1, 2))

        return utils.PlaceHolder(X=X, E=E, y=y).mask(node_mask)
|