omarmomen committed on
Commit 3a279f6
1 parent: 2adc920

Update structformer_as_hf.py

Files changed (1)
  1. structformer_as_hf.py +412 -11
structformer_as_hf.py CHANGED
@@ -1,3 +1,45 @@
+Hugging Face's logo
+Hugging Face
+Search models, datasets, users...
+Models
+Datasets
+Spaces
+Docs
+Solutions
+Pricing
+
+
+
+
+omarmomen
+/
+structformer_s1_final
+
+like
+0
+Fill-Mask
+Transformers
+PyTorch
+structformer
+custom_code
+Model card
+Files and versions
+Community
+Settings
+structformer_s1_final
+/
+structformer_as_hf.py
+Omar
+upfdate
+a7e60f9
+3 months ago
+raw
+history
+blame
+edit
+delete
+No virus
+36.1 kB
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,6 +48,13 @@ from transformers import PreTrainedModel
 from transformers import PretrainedConfig
 from transformers.modeling_outputs import MaskedLMOutput
 from typing import List
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    SequenceClassifierOutput
+)

 ##########################################
 # HuggingFace Config
@@ -67,7 +116,6 @@ class Conv1d(nn.Module):

   def __init__(self, hidden_size, kernel_size, dilation=1):
     """Initialization.
-
     Args:
       hidden_size: dimension of input embeddings
       kernel_size: convolution kernel size
@@ -90,7 +138,6 @@ class Conv1d(nn.Module):

   def forward(self, x):
     """Compute convolution.
-
     Args:
       x: input embeddings
     Returns:
@@ -114,7 +161,6 @@ class MultiheadAttention(nn.Module):
                out_proj=True,
                relative_bias=True):
     """Initialization.
-
     Args:
       embed_dim: dimension of input embeddings
       num_heads: number of self-attention heads
@@ -174,7 +220,6 @@ class MultiheadAttention(nn.Module):

   def forward(self, query, key_padding_mask=None, attn_mask=None):
     """Compute multi-head self-attention.
-
     Args:
       query: input embeddings
       key_padding_mask: 3D mask that prevents attention to certain positions
@@ -254,7 +299,6 @@ class TransformerLayer(nn.Module):
                activation="leakyrelu",
                relative_bias=True):
     """Initialization.
-
     Args:
       d_model: dimension of inputs
       nhead: number of self-attention heads
@@ -285,7 +329,6 @@ class TransformerLayer(nn.Module):

   def forward(self, src, attn_mask=None, key_padding_mask=None):
     """Pass the input through the encoder layer.
-
     Args:
       src: the sequence to the encoder layer (required).
       attn_mask: the mask for the src sequence (optional).
@@ -301,6 +344,30 @@ class TransformerLayer(nn.Module):

     return src3

+
+
+class RobertaClassificationHead(nn.Module):
+  """Head for sentence-level classification tasks."""
+
+  def __init__(self, config):
+    super().__init__()
+    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+    classifier_dropout = (
+        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+    )
+    self.dropout = nn.Dropout(classifier_dropout)
+    self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+  def forward(self, features, **kwargs):
+    x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+    x = self.dropout(x)
+    x = self.dense(x)
+    x = torch.tanh(x)
+    x = self.dropout(x)
+    x = self.out_proj(x)
+    return x
+
+
 ##########################################
 # Custom Models
 ##########################################
@@ -362,7 +429,6 @@ class Transformer(nn.Module):
                pos_emb=False,
                pad=0):
     """Initialization.
-
     Args:
       hidden_size: dimension of inputs and hidden states
       nlayers: number of layers
@@ -437,7 +503,6 @@ class Transformer(nn.Module):

   def forward(self, x, pos):
     """Pass the input through the encoder layer.
-
     Args:
       x: input tokens (required).
       pos: position for each token (optional).
@@ -474,7 +539,6 @@ class StructFormer(Transformer):
                relations=('head', 'child'),
                weight_act='softmax'):
     """Initialization.
-
     Args:
       hidden_size: dimension of inputs and hidden states
       nlayers: number of layers
@@ -548,7 +612,6 @@ class StructFormer(Transformer):

   def parse(self, x, pos, embeds=None):
     """Parse input sentence.
-
     Args:
       x: input tokens (required).
       pos: position for each token (optional).
@@ -735,6 +798,300 @@ class StructFormer(Transformer):
         attentions=None,
     )

+
+
+
+class StructFormerClassification(Transformer):
+  """StructFormer model."""
+
+  def __init__(self,
+               hidden_size,
+               n_context_layers,
+               nlayers,
+               ntokens,
+               nhead=8,
+               dropout=0.1,
+               dropatt=0.1,
+               relative_bias=False,
+               pos_emb=False,
+               pad=0,
+               n_parser_layers=4,
+               conv_size=9,
+               relations=('head', 'child'),
+               weight_act='softmax',
+               config=None,
+               ):
+
+
+    super(StructFormerClassification, self).__init__(
+        hidden_size,
+        nlayers,
+        ntokens,
+        nhead=nhead,
+        dropout=dropout,
+        dropatt=dropatt,
+        relative_bias=relative_bias,
+        pos_emb=pos_emb,
+        pad=pad)
+
+    self.num_labels = config.num_labels
+    self.config = config
+
+    if n_context_layers > 0:
+      self.context_layers = nn.ModuleList([
+          TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                           dropatt=dropatt, relative_bias=relative_bias)
+          for _ in range(n_context_layers)])
+
+    self.parser_layers = nn.ModuleList([
+        nn.Sequential(Conv1d(hidden_size, conv_size),
+                      nn.LayerNorm(hidden_size, elementwise_affine=False),
+                      nn.Tanh()) for i in range(n_parser_layers)])
+
+    self.distance_ff = nn.Sequential(
+        Conv1d(hidden_size, 2),
+        nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+        nn.Linear(hidden_size, 1))
+
+    self.height_ff = nn.Sequential(
+        nn.Linear(hidden_size, hidden_size),
+        nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+        nn.Linear(hidden_size, 1))
+
+    n_rel = len(relations)
+    self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
+    self._rel_weight.data.normal_(0, 0.1)
+
+    self._scaler = nn.Parameter(torch.zeros(2))
+
+    self.n_parse_layers = n_parser_layers
+    self.n_context_layers = n_context_layers
+    self.weight_act = weight_act
+    self.relations = relations
+
+    self.classifier = RobertaClassificationHead(config)
+
+  @property
+  def scaler(self):
+    return self._scaler.exp()
+
+  @property
+  def rel_weight(self):
+    if self.weight_act == 'sigmoid':
+      return torch.sigmoid(self._rel_weight)
+    elif self.weight_act == 'softmax':
+      return torch.softmax(self._rel_weight, dim=-1)
+
+  def parse(self, x, pos, embeds=None):
+    """Parse input sentence.
+    Args:
+      x: input tokens (required).
+      pos: position for each token (optional).
+    Returns:
+      distance: syntactic distance
+      height: syntactic height
+    """
+
+    mask = (x != self.pad)
+    mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
+
+
+    if embeds is not None:
+      h = embeds
+    else:
+      h = self.emb(x)
+
+    for i in range(self.n_parse_layers):
+      h = h.masked_fill(~mask[:, :, None], 0)
+      h = self.parser_layers[i](h)
+
+    height = self.height_ff(h).squeeze(-1)
+    height.masked_fill_(~mask, -1e9)
+
+    distance = self.distance_ff(h).squeeze(-1)
+    distance.masked_fill_(~mask_shifted, 1e9)
+
+    # Calbrating the distance and height to the same level
+    length = distance.size(1)
+    height_max = height[:, None, :].expand(-1, length, -1)
+    height_max = torch.cummax(
+        height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
+        dim=-1)[0].triu(0)
+
+    margin_left = torch.relu(
+        F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
+    margin_right = torch.relu(distance[:, None, :] - height_max)
+    margin = torch.where(margin_left > margin_right, margin_right,
+                         margin_left).triu(0)
+
+    margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
+    margin.masked_fill_(~margin_mask, 0)
+    margin = margin.max()
+
+    distance = distance - margin
+
+    return distance, height
+
+  def compute_block(self, distance, height):
+    """Compute constituents from distance and height."""
+
+    beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
+
+    gamma = torch.sigmoid(-beta_logits)
+    ones = torch.ones_like(gamma)
+
+    block_mask_left = cummin(
+        gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
+    block_mask_left = block_mask_left - F.pad(
+        block_mask_left[:, :, :-1], (1, 0), value=0)
+    block_mask_left.tril_(0)
+
+    block_mask_right = cummin(
+        gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
+    block_mask_right = block_mask_right - F.pad(
+        block_mask_right[:, :, 1:], (0, 1), value=0)
+    block_mask_right.triu_(0)
+
+    block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
+    block = cumsum(block_mask_left).tril(0) + cumsum(
+        block_mask_right, reverse=True).triu(1)
+
+    return block_p, block
+
+  def compute_head(self, height):
+    """Estimate head for each constituent."""
+
+    _, length = height.size()
+    head_logits = height * self.scaler[1]
+    index = torch.arange(length, device=height.device)
+
+    mask = (index[:, None, None] <= index[None, None, :]) * (
+        index[None, None, :] <= index[None, :, None])
+    head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
+    head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
+
+    head_p = torch.softmax(head_logits, dim=-1)
+
+    return head_p
+
+  def generate_mask(self, x, distance, height):
+    """Compute head and cibling distribution for each token."""
+
+    bsz, length = x.size()
+
+    eye = torch.eye(length, device=x.device, dtype=torch.bool)
+    eye = eye[None, :, :].expand((bsz, -1, -1))
+
+    block_p, block = self.compute_block(distance, height)
+    head_p = self.compute_head(height)
+    head = torch.einsum('blij,bijh->blh', block_p, head_p)
+    head = head.masked_fill(eye, 0)
+    child = head.transpose(1, 2)
+    cibling = torch.bmm(head, child).masked_fill(eye, 0)
+
+    rel_list = []
+    if 'head' in self.relations:
+      rel_list.append(head)
+    if 'child' in self.relations:
+      rel_list.append(child)
+    if 'cibling' in self.relations:
+      rel_list.append(cibling)
+
+    rel = torch.stack(rel_list, dim=1)
+
+    rel_weight = self.rel_weight
+
+    dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
+    att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
+
+    return att_mask, cibling, head, block
+
+  def encode(self, x, pos, att_mask=None, context_layers=False):
+    """Structformer encoding process."""
+
+    if context_layers:
+      """Standard transformer encode process."""
+      h = self.emb(x)
+      if hasattr(self, 'pos_emb'):
+        h = h + self.pos_emb(pos)
+      h_list = []
+      visibility = self.visibility(x, x.device)
+      for i in range(self.n_context_layers):
+        h_list.append(h)
+        h = self.context_layers[i](
+            h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
+
+      output = h
+      h_array = torch.stack(h_list, dim=2)
+      return output
+
+    else:
+      visibility = self.visibility(x, x.device)
+      h = self.emb(x)
+      if hasattr(self, 'pos_emb'):
+        assert pos.max() < 500
+        h = h + self.pos_emb(pos)
+      for i in range(self.nlayers):
+        h = self.layers[i](
+            h.transpose(0, 1), attn_mask=att_mask[i],
+            key_padding_mask=visibility).transpose(0, 1)
+      return h
+
+  def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
+
+    x = input_ids
+    batch_size, length = x.size()
+
+    if position_ids is None:
+      pos = torch.arange(length, device=x.device).expand(batch_size, length)
+
+    context_layers_output = None
+    if self.n_context_layers > 0:
+      context_layers_output = self.encode(x, pos, context_layers=True)
+
+    distance, height = self.parse(x, pos, embeds=context_layers_output)
+    att_mask, cibling, head, block = self.generate_mask(x, distance, height)
+
+    raw_output = self.encode(x, pos, att_mask)
+    raw_output = self.norm(raw_output)
+    raw_output = self.drop(raw_output)
+
+    #output = self.output_layer(raw_output)
+    logits = self.classifier(raw_output)
+
+    loss = None
+    if labels is not None:
+      if self.config.problem_type is None:
+        if self.num_labels == 1:
+          self.config.problem_type = "regression"
+        elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+          self.config.problem_type = "single_label_classification"
+        else:
+          self.config.problem_type = "multi_label_classification"
+
+      if self.config.problem_type == "regression":
+        loss_fct = MSELoss()
+        if self.num_labels == 1:
+          loss = loss_fct(logits.squeeze(), labels.squeeze())
+        else:
+          loss = loss_fct(logits, labels)
+      elif self.config.problem_type == "single_label_classification":
+        loss_fct = CrossEntropyLoss()
+        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+      elif self.config.problem_type == "multi_label_classification":
+        loss_fct = BCEWithLogitsLoss()
+        loss = loss_fct(logits, labels)
+
+
+    return SequenceClassifierOutput(
+        loss=loss,
+        logits=logits,
+        hidden_states=None,
+        attentions=None,
+    )
+
+
+
 ##########################################
 # HuggingFace Model
 ##########################################
@@ -761,4 +1118,48 @@ class StructformerModel(PreTrainedModel):
     )

   def forward(self, input_ids, labels=None, **kwargs):
-    return self.model(input_ids, labels=labels, **kwargs)
+    return self.model(input_ids, labels=labels, **kwargs)
+
+
+
+class StructformerModelForSequenceClassification(PreTrainedModel):
+  config_class = StructformerConfig
+  def __init__(self, config):
+    super().__init__(config)
+    self.model = StructFormerClassification(
+        hidden_size=config.hidden_size,
+        n_context_layers=config.n_context_layers,
+        nlayers=config.nlayers,
+        ntokens=config.ntokens,
+        nhead=config.nhead,
+        dropout=config.dropout,
+        dropatt=config.dropatt,
+        relative_bias=config.relative_bias,
+        pos_emb=config.pos_emb,
+        pad=config.pad,
+        n_parser_layers=config.n_parser_layers,
+        conv_size=config.conv_size,
+        relations=config.relations,
+        weight_act=config.weight_act,
+        config=config)
+
+  def _init_weights(self, module):
+    """Initialize the weights"""
+    if isinstance(module, nn.Linear):
+      # Slightly different from the TF version which uses truncated_normal for initialization
+      # cf https://github.com/pytorch/pytorch/pull/5617
+      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+      if module.bias is not None:
+        module.bias.data.zero_()
+    elif isinstance(module, nn.Embedding):
+      module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+      if module.padding_idx is not None:
+        module.weight.data[module.padding_idx].zero_()
+    elif isinstance(module, nn.LayerNorm):
+      if module.bias is not None:
+        module.bias.data.zero_()
+      module.weight.data.fill_(1.0)
+
+
+  def forward(self, input_ids, labels=None, **kwargs):
+    return self.model(input_ids, labels=labels, **kwargs)
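
With this change the file exposes a second PreTrainedModel wrapper, StructformerModelForSequenceClassification, next to the existing fill-mask wrapper. A minimal usage sketch follows; it assumes the omarmomen/structformer_s1_final checkpoint registers these classes for the Auto API via its config (the custom_code tag suggests so, but that mapping is not part of this diff), that the repo ships a compatible tokenizer, and that num_labels=2 is a hypothetical binary task rather than anything defined here.

from transformers import AutoTokenizer, AutoModelForSequenceClassification

repo = "omarmomen/structformer_s1_final"

# trust_remote_code is needed so transformers executes the custom classes in
# structformer_as_hf.py; num_labels is forwarded to the config and read by
# RobertaClassificationHead.
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForSequenceClassification.from_pretrained(
    repo, trust_remote_code=True, num_labels=2)

inputs = tokenizer("The cat sat on the mat.", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"])
print(outputs.logits.shape)  # expected: torch.Size([1, 2])

Note that StructFormerClassification.forward only reads input_ids, labels, and position_ids; other tokenizer outputs such as attention_mask fall into **kwargs and are ignored, since padding is instead derived from the configured pad id inside parse and the inherited visibility mask.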