maciek-g committed
Commit 24ade06
Parent(s): 07b1bb7

Create particle_transfomer.py

Files changed (1)
  1. particle_transfomer.py +175 -0
particle_transfomer.py ADDED
@@ -0,0 +1,175 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Embed(nn.Module):
    """Per-object feature embedding: three fully connected layers."""

    def __init__(self, input_dim, output_dim, normalize_input=False, event_level=False, activation='gelu'):
        super().__init__()

        self.input_bn = nn.BatchNorm1d(input_dim) if normalize_input else None
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.fc2 = nn.Linear(output_dim, output_dim)
        self.fc3 = nn.Linear(output_dim, output_dim)
        self.event_level = event_level

    def forward(self, x):
        if self.input_bn is not None:
            # x: (batch, input_dim, seq_len) -- BatchNorm1d normalizes over the feature dimension
            x = self.input_bn(x)
        if not self.event_level:
            # (batch, input_dim, seq_len) -> (seq_len, batch, input_dim) for sequence-first attention
            x = x.permute(2, 0, 1).contiguous()

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))

        return x

class AttBlock(nn.Module):
    """Self-attention block: pre-norm multi-head attention followed by a residual feed-forward network."""

    def __init__(self, embed_dims, linear_dims1, linear_dims2, num_heads=8, activation='relu'):
        super(AttBlock, self).__init__()

        self.layer_norm1 = nn.LayerNorm(embed_dims)
        self.multihead_attention = nn.MultiheadAttention(embed_dims, num_heads)
        self.layer_norm2 = nn.LayerNorm(embed_dims)
        self.linear1 = nn.Linear(embed_dims, linear_dims1)
        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.layer_norm3 = nn.LayerNorm(linear_dims1)
        self.linear2 = nn.Linear(linear_dims1, linear_dims2)

    def forward(self, x, padding_mask=None):
        # Layer normalization 1
        x = self.layer_norm1(x)

        if padding_mask is not None:
            # The mask uses 0 for real objects and 1 for padded slots; key_padding_mask
            # expects a boolean tensor with True at padded locations.
            padding_mask = padding_mask.bool()

        x_att, attention = self.multihead_attention(x, x, x, key_padding_mask=padding_mask,
                                                    need_weights=True, average_attn_weights=True)

        # Skip connection around the attention layer
        x = x + x_att
        # Layer normalization 2
        x = self.layer_norm2(x)
        # First linear layer with activation, followed by a skip connection
        x_linear1 = self.activation(self.linear1(x))
        x = x + x_linear1
        # Layer normalization 3
        x = self.layer_norm3(x)
        # Second linear layer with skip connection
        x_linear2 = self.linear2(x)
        x = x + x_linear2
        return x, attention

class ClassBlock(nn.Module):
    """Class-attention block: a class token attends to the full object sequence."""

    def __init__(self, embed_dims, linear_dims1, linear_dims2, num_heads=8, activation='relu'):
        super(ClassBlock, self).__init__()

        self.layer_norm1 = nn.LayerNorm(embed_dims)
        self.multihead_attention = nn.MultiheadAttention(embed_dims, num_heads)
        self.layer_norm2 = nn.LayerNorm(embed_dims)
        self.linear1 = nn.Linear(embed_dims, linear_dims1)
        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.layer_norm3 = nn.LayerNorm(linear_dims1)
        self.linear2 = nn.Linear(linear_dims1, linear_dims2)

    def forward(self, x, class_token, padding_mask=None):
        # Prepend the class token to the input sequence along the sequence-length dimension
        x = torch.cat((class_token, x), dim=0)  # (seq_len + 1, batch, embed_dim)
        # Layer normalization 1
        x = self.layer_norm1(x)

        if padding_mask is not None:
            # Prepend an unmasked entry for the class token so the mask covers all seq_len + 1 keys
            padding_mask = torch.cat((torch.zeros_like(padding_mask[:, :1]), padding_mask), dim=1)
            padding_mask = padding_mask.bool()

        # The class token is the only query; the full sequence provides keys and values
        x_att, attention = self.multihead_attention(class_token, x, x, key_padding_mask=padding_mask,
                                                    need_weights=True, average_attn_weights=False)
        # Layer normalization 2 and skip connection back onto the class token
        x = self.layer_norm2(x_att)
        x = class_token + x
        # Linear layer and activation
        x_linear1 = self.activation(self.linear1(x))
        # Layer normalization 3
        x_linear1 = self.layer_norm3(x_linear1)
        # Linear layer with the specified output dimensions, plus skip connection
        x_linear2 = self.linear2(x_linear1)
        x = x + x_linear2
        return x, attention

class MLPHead(nn.Module):
    """Three-layer MLP classification head."""

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super(MLPHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class AnalysisObjectTransformer(nn.Module):
    """Object-level transformer with a class-attention head for event classification."""

    def __init__(self, input_dim_obj, input_dim_event, embed_dims, linear_dims1, linear_dims2,
                 mlp_hidden_1, mlp_hidden_2, num_heads=8):
        super(AnalysisObjectTransformer, self).__init__()

        self.embed_dims = embed_dims

        # Embedding layers for object-level and event-level inputs
        self.embedding_layer = Embed(input_dim_obj, embed_dims)
        self.embedding_layer_event_level = Embed(input_dim_event, embed_dims, event_level=True)

        # Three blocks of self-attention over the objects
        self.block1 = AttBlock(embed_dims, linear_dims1, linear_dims1, num_heads)
        self.block2 = AttBlock(linear_dims1, linear_dims1, linear_dims1, num_heads)
        self.block3 = AttBlock(linear_dims1, linear_dims2, linear_dims2, num_heads)
        # Three class-attention blocks that distil the sequence into the class token
        self.block5 = ClassBlock(linear_dims2, linear_dims1, linear_dims2, num_heads)
        self.block6 = ClassBlock(linear_dims2, linear_dims1, linear_dims2, num_heads)
        self.block7 = ClassBlock(linear_dims2, linear_dims1, linear_dims2, num_heads)

        # Learnable class token, registered as a parameter so it is trained and moved with the module
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        nn.init.trunc_normal_(self.cls_token, std=.02)

        # Output MLP and sigmoid activation (assumes linear_dims2 == embed_dims so the
        # class token matches the declared MLP input size)
        self.mlp = MLPHead(embed_dims + input_dim_event, mlp_hidden_1, mlp_hidden_2, output_dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, event_level, mask=None):
        # Object embedding: (batch, input_dim_obj, n_objects) -> (n_objects, batch, embed_dims)
        x = self.embedding_layer(x)

        attention_weights = []

        # Three blocks of self-attention
        x, attention = self.block1(x, padding_mask=mask)
        attention_weights.append(attention)
        x, attention = self.block2(x, padding_mask=mask)
        attention_weights.append(attention)
        x, attention = self.block3(x, padding_mask=mask)
        attention_weights.append(attention)

        # Class-attention blocks: the class token attends to the object sequence
        cls_tokens = self.cls_token.expand(1, x.size(1), -1)  # (1, batch, embed_dims)
        cls_tokens, attention = self.block5(x, cls_tokens, padding_mask=mask)
        cls_tokens, attention = self.block6(x, cls_tokens, padding_mask=mask)
        cls_tokens, attention = self.block7(x, cls_tokens, padding_mask=mask)

        # Concatenate the class token with the event-level features and classify
        x = torch.cat((cls_tokens.squeeze(0), event_level), dim=-1)
        x = self.mlp(x)
        output_probabilities = self.sigmoid(x)
        return output_probabilities, attention_weights
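
Usage sketch (not part of the committed file): a minimal example of wiring up AnalysisObjectTransformer, assuming object features arrive as (batch, n_features, n_objects), event-level features as one flat vector per event, and a padding mask with 1 marking padded object slots. All dimension values below are illustrative, not taken from the commit.

import torch
from particle_transfomer import AnalysisObjectTransformer

# Illustrative sizes: 8 events, up to 12 objects, 6 features per object,
# 4 event-level features, 64-dim embeddings throughout.
batch, n_objects, n_obj_feats, n_event_feats, dim = 8, 12, 6, 4, 64

model = AnalysisObjectTransformer(
    input_dim_obj=n_obj_feats, input_dim_event=n_event_feats,
    embed_dims=dim, linear_dims1=dim, linear_dims2=dim,
    mlp_hidden_1=128, mlp_hidden_2=64, num_heads=8,
)

objects = torch.randn(batch, n_obj_feats, n_objects)   # (batch, features, objects)
event_feats = torch.randn(batch, n_event_feats)        # per-event features
mask = torch.zeros(batch, n_objects)                   # 0 = real object, 1 = padded
mask[:, 10:] = 1                                       # pretend the last two slots are padding

probs, attn = model(objects, event_feats, mask=mask)
print(probs.shape)   # torch.Size([8, 1]) -- per-event probability from the sigmoid head
print(len(attn))     # 3 -- attention maps collected from the three self-attention blocks

Keeping every width equal to embed_dims, as above, satisfies the residual connections inside AttBlock and ClassBlock, which add tensors of the block's input and output widths.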