Punyajoy committed
Commit 8f8ba7a
1 Parent(s): 8cadb97

Update models.py

Files changed (1):
  1. models.py +13 -1
models.py CHANGED
@@ -2,9 +2,21 @@ import torch
 from transformers import AutoModelForTokenClassification, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
 from transformers import BertForTokenClassification, BertForSequenceClassification,BertPreTrainedModel, BertModel
 import torch.nn as nn
-from .utils import *
 import torch.nn.functional as F
 
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
 
 
 class Model_Rational_Label(BertPreTrainedModel):
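
A minimal usage sketch of the newly added BertPooler, for readers who want to see its effect in isolation. It assumes the file above is importable as models, uses a SimpleNamespace stand-in for the BERT config (only hidden_size is read), and picks arbitrary batch/sequence sizes for illustration; none of these choices are part of the commit.

import torch
from types import SimpleNamespace
from models import BertPooler  # assumption: the models.py shown above is on the import path

# Stand-in config; BertPooler only reads hidden_size from it.
config = SimpleNamespace(hidden_size=768)
pooler = BertPooler(config)

# Fake encoder output shaped (batch_size, sequence_length, hidden_size).
hidden_states = torch.randn(4, 128, config.hidden_size)

# The pooler keeps only the first ([CLS]) token's hidden state and applies
# a dense layer followed by tanh, giving a (batch_size, hidden_size) tensor.
pooled = pooler(hidden_states)
print(pooled.shape)  # torch.Size([4, 768])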