dingzx97 commited on
Commit
c23dd90
1 Parent(s): e03b2ec
Files changed (2) hide show
  1. modeling_lddbert.py +16 -5
  2. pytorch_model.bin +2 -2
modeling_lddbert.py CHANGED
@@ -383,10 +383,18 @@ class LddBertModel(LddBertPreTrainedModel):
383
  self.embeddings = Embeddings(config) # Embeddings
384
  self.transformer = Transformer(config) # Encoder
385
  self.gru = nn.GRU(config.dim , config.dim//2, config.n_gru_layers, batch_first=True, bidirectional=True)
386
- self.cnn = nn.Sequential(*(
387
- nn.Conv1d(config.max_position_embeddings, config.max_position_embeddings, config.cnn_kernel_size, padding=(config.cnn_kernel_size-1)//2)
 
 
 
 
 
 
 
 
388
  for _ in range(config.n_cnn_layers)
389
- ))
390
 
391
  # Initialize weights and apply final processing
392
  self.post_init()
@@ -511,9 +519,12 @@ class LddBertModel(LddBertPreTrainedModel):
511
 
512
  gru_output, _ = self.gru(bert_output[0])
513
 
514
- cnn_output = self.cnn(bert_output[0])
 
 
 
515
 
516
- output = gru_output + cnn_output
517
  if not return_dict:
518
  return (output, ) + bert_output[1:]
519
 
 
383
  self.embeddings = Embeddings(config) # Embeddings
384
  self.transformer = Transformer(config) # Encoder
385
  self.gru = nn.GRU(config.dim , config.dim//2, config.n_gru_layers, batch_first=True, bidirectional=True)
386
+
387
+ self.activation_cnn = get_activation('relu')
388
+ self.cnn = nn.ModuleList([
389
+ nn.Sequential(
390
+ nn.Conv2d(in_channels=1,
391
+ out_channels=1,
392
+ kernel_size=config.cnn_kernel_size,
393
+ padding=(config.cnn_kernel_size-1)//2),
394
+ self.activation_cnn
395
+ )
396
  for _ in range(config.n_cnn_layers)
397
+ ])
398
 
399
  # Initialize weights and apply final processing
400
  self.post_init()
 
519
 
520
  gru_output, _ = self.gru(bert_output[0])
521
 
522
+ cnn_output = bert_output[0].view(input_shape[0], 1, input_shape[1], -1)
523
+ for i, layer_module in enumerate(self.cnn):
524
+ cnn_output = layer_module(cnn_output)
525
+ cnn_output = cnn_output.view(input_shape[0], input_shape[1], -1)
526
 
527
+ output = gru_output + cnn_output
528
  if not return_dict:
529
  return (output, ) + bert_output[1:]
530
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:151f439844ff10c523e93c90fbce4a543ab1bcce6f660822748eae4bd2e9c94c
3
- size 363280885
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:189cb4b46e7ca027e3dd89c6f57b1c15e77bc2a58dd1620dcd7dde62d8f42816
3
+ size 331811701