jinmang2 commited on
Commit
e025643
1 Parent(s): 8d3fe9a

Update modeling_textcnn.py

Browse files
Files changed (1) hide show
  1. modeling_textcnn.py +9 -6
modeling_textcnn.py CHANGED
@@ -61,14 +61,17 @@ class TextCNNModel(TextCNNPreTrainedModel):
61
 
62
  def forward(self, input_ids):
63
  # input_ids.shape == (bsz, seq_len)
64
- x = self.embeder(input_ids).unsqueeze(1) # add channel dim
65
  # x.shape == (bsz, 1, seq_len, emb_dim)
66
- convs = [torch.relu(conv(x)).squeeze(3) for conv in self.convs]
67
- # convs[i].shape == (bsz, n_filter[i], ngram_seq_len)
68
- pools = [torch.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs]
69
- # pools[i].shape == (bsz, n_filter[i])
70
- outputs = torch.cat(pools, 1)
 
 
 
71
  # outputs.shape == (bsz, feature_dim)
 
72
 
73
  return TextCNNModelOutput(
74
  last_hidden_states=outputs,
 
61
 
62
  def forward(self, input_ids):
63
  # input_ids.shape == (bsz, seq_len)
 
64
  # x.shape == (bsz, 1, seq_len, emb_dim)
65
+ x = self.embeder(input_ids).unsqueeze(1) # add channel dim
66
+ outputs = []
67
+ for conv in self.convs:
68
+ # conv_output.shape == (bsz, n_filter[i], ngram_seq_len)
69
+ conv_output = torch.relu(conv(x)).squeeze(3)
70
+ # output.shape == (bsz, n_filter[i])
71
+ output = torch.max_pool1d(conv_output, conv_output.size(2)).squeeze(2)
72
+ outputs.append(output)
73
  # outputs.shape == (bsz, feature_dim)
74
+ outputs = torch.cat(outputs, dim=1)
75
 
76
  return TextCNNModelOutput(
77
  last_hidden_states=outputs,