omarmomen committed
Commit 15b31ae
1 Parent(s): 6e35671

Update structformer_as_hf.py

Files changed (1): structformer_as_hf.py (+369 -10)

structformer_as_hf.py CHANGED
@@ -6,6 +6,13 @@ from transformers import PreTrainedModel
 from transformers import PretrainedConfig
 from transformers.modeling_outputs import MaskedLMOutput
 from typing import List
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    SequenceClassifierOutput
+)
 
 ##########################################
 # HuggingFace Config
@@ -67,7 +74,6 @@ class Conv1d(nn.Module):
 
   def __init__(self, hidden_size, kernel_size, dilation=1):
     """Initialization.
-
     Args:
       hidden_size: dimension of input embeddings
       kernel_size: convolution kernel size
@@ -90,7 +96,6 @@ class Conv1d(nn.Module):
 
   def forward(self, x):
     """Compute convolution.
-
     Args:
       x: input embeddings
     Returns:
@@ -114,7 +119,6 @@ class MultiheadAttention(nn.Module):
                out_proj=True,
                relative_bias=True):
     """Initialization.
-
     Args:
       embed_dim: dimension of input embeddings
       num_heads: number of self-attention heads
@@ -174,7 +178,6 @@ class MultiheadAttention(nn.Module):
 
   def forward(self, query, key_padding_mask=None, attn_mask=None):
     """Compute multi-head self-attention.
-
     Args:
       query: input embeddings
       key_padding_mask: 3D mask that prevents attention to certain positions
@@ -254,7 +257,6 @@ class TransformerLayer(nn.Module):
                activation="leakyrelu",
                relative_bias=True):
     """Initialization.
-
     Args:
       d_model: dimension of inputs
       nhead: number of self-attention heads
@@ -285,7 +287,6 @@ class TransformerLayer(nn.Module):
 
   def forward(self, src, attn_mask=None, key_padding_mask=None):
     """Pass the input through the encoder layer.
-
     Args:
       src: the sequence to the encoder layer (required).
       attn_mask: the mask for the src sequence (optional).
@@ -301,6 +302,30 @@ class TransformerLayer(nn.Module):
 
     return src3
 
+
+
+class RobertaClassificationHead(nn.Module):
+  """Head for sentence-level classification tasks."""
+
+  def __init__(self, config):
+    super().__init__()
+    self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+    classifier_dropout = (
+        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+    )
+    self.dropout = nn.Dropout(classifier_dropout)
+    self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+  def forward(self, features, **kwargs):
+    x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+    x = self.dropout(x)
+    x = self.dense(x)
+    x = torch.tanh(x)
+    x = self.dropout(x)
+    x = self.out_proj(x)
+    return x
+
+
 ##########################################
 # Custom Models
 ##########################################
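The added `RobertaClassificationHead` pools the hidden state at position 0 (the `<s>`/[CLS] slot), applies dropout, a tanh-activated dense layer, and a final projection to `num_labels` logits. A minimal sketch of exercising it in isolation; the `SimpleNamespace` config values and tensor shapes below are illustrative assumptions, not part of this commit:

```python
from types import SimpleNamespace
import torch

# Hypothetical config values, only for this sketch.
cfg = SimpleNamespace(hidden_size=768, num_labels=2,
                      classifier_dropout=None, hidden_dropout_prob=0.1)
head = RobertaClassificationHead(cfg)

hidden_states = torch.randn(4, 16, 768)  # [batch, seq_len, hidden_size]
logits = head(hidden_states)             # pools position 0 -> shape [4, 2]
print(logits.shape)
```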
@@ -362,7 +387,6 @@ class Transformer(nn.Module):
                pos_emb=False,
                pad=0):
     """Initialization.
-
     Args:
       hidden_size: dimension of inputs and hidden states
       nlayers: number of layers
@@ -437,7 +461,6 @@ class Transformer(nn.Module):
 
   def forward(self, x, pos):
     """Pass the input through the encoder layer.
-
     Args:
       x: input tokens (required).
       pos: position for each token (optional).
@@ -474,7 +497,6 @@ class StructFormer(Transformer):
                relations=('head', 'child'),
                weight_act='softmax'):
     """Initialization.
-
     Args:
       hidden_size: dimension of inputs and hidden states
       nlayers: number of layers
@@ -548,7 +570,6 @@ class StructFormer(Transformer):
 
   def parse(self, x, pos, embeds=None):
     """Parse input sentence.
-
     Args:
       x: input tokens (required).
       pos: position for each token (optional).
@@ -735,6 +756,300 @@ class StructFormer(Transformer):
         attentions=None,
     )
 
+
+
+
+class StructFormerClassification(Transformer):
+  """StructFormer model."""
+
+  def __init__(self,
+               hidden_size,
+               n_context_layers,
+               nlayers,
+               ntokens,
+               nhead=8,
+               dropout=0.1,
+               dropatt=0.1,
+               relative_bias=False,
+               pos_emb=False,
+               pad=0,
+               n_parser_layers=4,
+               conv_size=9,
+               relations=('head', 'child'),
+               weight_act='softmax',
+               config=None,
+               ):
+
+
+    super(StructFormerClassification, self).__init__(
+        hidden_size,
+        nlayers,
+        ntokens,
+        nhead=nhead,
+        dropout=dropout,
+        dropatt=dropatt,
+        relative_bias=relative_bias,
+        pos_emb=pos_emb,
+        pad=pad)
+
+    self.num_labels = config.num_labels
+    self.config = config
+
+    if n_context_layers > 0:
+      self.context_layers = nn.ModuleList([
+          TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                           dropatt=dropatt, relative_bias=relative_bias)
+          for _ in range(n_context_layers)])
+
+    self.parser_layers = nn.ModuleList([
+        nn.Sequential(Conv1d(hidden_size, conv_size),
+                      nn.LayerNorm(hidden_size, elementwise_affine=False),
+                      nn.Tanh()) for i in range(n_parser_layers)])
+
+    self.distance_ff = nn.Sequential(
+        Conv1d(hidden_size, 2),
+        nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+        nn.Linear(hidden_size, 1))
+
+    self.height_ff = nn.Sequential(
+        nn.Linear(hidden_size, hidden_size),
+        nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+        nn.Linear(hidden_size, 1))
+
+    n_rel = len(relations)
+    self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
+    self._rel_weight.data.normal_(0, 0.1)
+
+    self._scaler = nn.Parameter(torch.zeros(2))
+
+    self.n_parse_layers = n_parser_layers
+    self.n_context_layers = n_context_layers
+    self.weight_act = weight_act
+    self.relations = relations
+
+    self.classifier = RobertaClassificationHead(config)
+
+  @property
+  def scaler(self):
+    return self._scaler.exp()
+
+  @property
+  def rel_weight(self):
+    if self.weight_act == 'sigmoid':
+      return torch.sigmoid(self._rel_weight)
+    elif self.weight_act == 'softmax':
+      return torch.softmax(self._rel_weight, dim=-1)
+
+  def parse(self, x, pos, embeds=None):
+    """Parse input sentence.
+    Args:
+      x: input tokens (required).
+      pos: position for each token (optional).
+    Returns:
+      distance: syntactic distance
+      height: syntactic height
+    """
+
+    mask = (x != self.pad)
+    mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
+
+
+    if embeds is not None:
+      h = embeds
+    else:
+      h = self.emb(x)
+
+    for i in range(self.n_parse_layers):
+      h = h.masked_fill(~mask[:, :, None], 0)
+      h = self.parser_layers[i](h)
+
+    height = self.height_ff(h).squeeze(-1)
+    height.masked_fill_(~mask, -1e9)
+
+    distance = self.distance_ff(h).squeeze(-1)
+    distance.masked_fill_(~mask_shifted, 1e9)
+
+    # Calbrating the distance and height to the same level
+    length = distance.size(1)
+    height_max = height[:, None, :].expand(-1, length, -1)
+    height_max = torch.cummax(
+        height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
+        dim=-1)[0].triu(0)
+
+    margin_left = torch.relu(
+        F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
+    margin_right = torch.relu(distance[:, None, :] - height_max)
+    margin = torch.where(margin_left > margin_right, margin_right,
+                         margin_left).triu(0)
+
+    margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
+    margin.masked_fill_(~margin_mask, 0)
+    margin = margin.max()
+
+    distance = distance - margin
+
+    return distance, height
+
+  def compute_block(self, distance, height):
+    """Compute constituents from distance and height."""
+
+    beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
+
+    gamma = torch.sigmoid(-beta_logits)
+    ones = torch.ones_like(gamma)
+
+    block_mask_left = cummin(
+        gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
+    block_mask_left = block_mask_left - F.pad(
+        block_mask_left[:, :, :-1], (1, 0), value=0)
+    block_mask_left.tril_(0)
+
+    block_mask_right = cummin(
+        gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
+    block_mask_right = block_mask_right - F.pad(
+        block_mask_right[:, :, 1:], (0, 1), value=0)
+    block_mask_right.triu_(0)
+
+    block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
+    block = cumsum(block_mask_left).tril(0) + cumsum(
+        block_mask_right, reverse=True).triu(1)
+
+    return block_p, block
+
+  def compute_head(self, height):
+    """Estimate head for each constituent."""
+
+    _, length = height.size()
+    head_logits = height * self.scaler[1]
+    index = torch.arange(length, device=height.device)
+
+    mask = (index[:, None, None] <= index[None, None, :]) * (
+        index[None, None, :] <= index[None, :, None])
+    head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
+    head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
+
+    head_p = torch.softmax(head_logits, dim=-1)
+
+    return head_p
+
+  def generate_mask(self, x, distance, height):
+    """Compute head and cibling distribution for each token."""
+
+    bsz, length = x.size()
+
+    eye = torch.eye(length, device=x.device, dtype=torch.bool)
+    eye = eye[None, :, :].expand((bsz, -1, -1))
+
+    block_p, block = self.compute_block(distance, height)
+    head_p = self.compute_head(height)
+    head = torch.einsum('blij,bijh->blh', block_p, head_p)
+    head = head.masked_fill(eye, 0)
+    child = head.transpose(1, 2)
+    cibling = torch.bmm(head, child).masked_fill(eye, 0)
+
+    rel_list = []
+    if 'head' in self.relations:
+      rel_list.append(head)
+    if 'child' in self.relations:
+      rel_list.append(child)
+    if 'cibling' in self.relations:
+      rel_list.append(cibling)
+
+    rel = torch.stack(rel_list, dim=1)
+
+    rel_weight = self.rel_weight
+
+    dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
+    att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
+
+    return att_mask, cibling, head, block
+
+  def encode(self, x, pos, att_mask=None, context_layers=False):
+    """Structformer encoding process."""
+
+    if context_layers:
+      """Standard transformer encode process."""
+      h = self.emb(x)
+      if hasattr(self, 'pos_emb'):
+        h = h + self.pos_emb(pos)
+      h_list = []
+      visibility = self.visibility(x, x.device)
+      for i in range(self.n_context_layers):
+        h_list.append(h)
+        h = self.context_layers[i](
+            h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
+
+      output = h
+      h_array = torch.stack(h_list, dim=2)
+      return output
+
+    else:
+      visibility = self.visibility(x, x.device)
+      h = self.emb(x)
+      if hasattr(self, 'pos_emb'):
+        assert pos.max() < 500
+        h = h + self.pos_emb(pos)
+      for i in range(self.nlayers):
+        h = self.layers[i](
+            h.transpose(0, 1), attn_mask=att_mask[i],
+            key_padding_mask=visibility).transpose(0, 1)
+      return h
+
+  def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
+
+    x = input_ids
+    batch_size, length = x.size()
+
+    if position_ids is None:
+      pos = torch.arange(length, device=x.device).expand(batch_size, length)
+
+    context_layers_output = None
+    if self.n_context_layers > 0:
+      context_layers_output = self.encode(x, pos, context_layers=True)
+
+    distance, height = self.parse(x, pos, embeds=context_layers_output)
+    att_mask, cibling, head, block = self.generate_mask(x, distance, height)
+
+    raw_output = self.encode(x, pos, att_mask)
+    raw_output = self.norm(raw_output)
+    raw_output = self.drop(raw_output)
+
+    #output = self.output_layer(raw_output)
+    logits = self.classifier(raw_output)
+
+    loss = None
+    if labels is not None:
+      if self.config.problem_type is None:
+        if self.num_labels == 1:
+          self.config.problem_type = "regression"
+        elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+          self.config.problem_type = "single_label_classification"
+        else:
+          self.config.problem_type = "multi_label_classification"
+
+      if self.config.problem_type == "regression":
+        loss_fct = MSELoss()
+        if self.num_labels == 1:
+          loss = loss_fct(logits.squeeze(), labels.squeeze())
+        else:
+          loss = loss_fct(logits, labels)
+      elif self.config.problem_type == "single_label_classification":
+        loss_fct = CrossEntropyLoss()
+        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+      elif self.config.problem_type == "multi_label_classification":
+        loss_fct = BCEWithLogitsLoss()
+        loss = loss_fct(logits, labels)
+
+
+    return SequenceClassifierOutput(
+        loss=loss,
+        logits=logits,
+        hidden_states=None,
+        attentions=None,
+    )
+
+
+
 ##########################################
 # HuggingFace Model
 ##########################################
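The classification forward pass added above follows the usual Hugging Face recipe: if `config.problem_type` is unset it is inferred from `num_labels` and the label dtype, and the loss is then MSE, cross-entropy, or BCE-with-logits accordingly. A minimal standalone sketch of that dispatch; the helper name and the example tensors are assumptions for illustration, not part of the commit:

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss

def classification_loss(logits, labels, num_labels, problem_type=None):
  # Mirrors the problem_type inference used in StructFormerClassification.forward.
  if problem_type is None:
    if num_labels == 1:
      problem_type = "regression"
    elif labels.dtype in (torch.long, torch.int):
      problem_type = "single_label_classification"
    else:
      problem_type = "multi_label_classification"

  if problem_type == "regression":
    return MSELoss()(logits.squeeze(), labels.squeeze())
  if problem_type == "single_label_classification":
    return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
  return BCEWithLogitsLoss()(logits, labels)

# Single-label example: [batch, num_labels] logits, integer class labels.
print(classification_loss(torch.randn(4, 2), torch.tensor([0, 1, 1, 0]), num_labels=2))
```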
@@ -760,5 +1075,49 @@ class StructformerModel(PreTrainedModel):
             weight_act=config.weight_act
         )
 
+    def forward(self, input_ids, labels=None, **kwargs):
+        return self.model(input_ids, labels=labels, **kwargs)
+
+
+
+class StructformerModelForSequenceClassification(PreTrainedModel):
+    config_class = StructformerConfig
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = StructFormerClassification(
+            hidden_size=config.hidden_size,
+            n_context_layers=config.n_context_layers,
+            nlayers=config.nlayers,
+            ntokens=config.ntokens,
+            nhead=config.nhead,
+            dropout=config.dropout,
+            dropatt=config.dropatt,
+            relative_bias=config.relative_bias,
+            pos_emb=config.pos_emb,
+            pad=config.pad,
+            n_parser_layers=config.n_parser_layers,
+            conv_size=config.conv_size,
+            relations=config.relations,
+            weight_act=config.weight_act,
+            config=config)
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            if module.bias is not None:
+                module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
     def forward(self, input_ids, labels=None, **kwargs):
         return self.model(input_ids, labels=labels, **kwargs)
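Both wrappers subclass `PreTrainedModel` with `config_class = StructformerConfig`, so a checkpoint that ships this file can be loaded through the Auto classes with remote code enabled. A hedged sketch: the repo id below is a placeholder, and it assumes the checkpoint's `config.json` maps `AutoModelForSequenceClassification` to `StructformerModelForSequenceClassification` in its `auto_map`:

```python
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

repo = "omarmomen/..."  # placeholder repo id; replace with the actual checkpoint
config = AutoConfig.from_pretrained(repo, trust_remote_code=True, num_labels=2)
model = AutoModelForSequenceClassification.from_pretrained(
    repo, config=config, trust_remote_code=True)

tok = AutoTokenizer.from_pretrained(repo)
batch = tok("a simple test sentence", return_tensors="pt")
with torch.no_grad():
  out = model(batch["input_ids"], labels=torch.tensor([1]))
print(out.loss, out.logits.shape)
```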