SHSH0819 commited on
Commit
f8e416f
1 Parent(s): 0a0161b

Upload event_detection_model.py

Browse files
Files changed (1) hide show
  1. event_detection_model.py +29 -0
event_detection_model.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel
3
+ from transformers import AutoModelForMaskedLM
4
+
5
+ class DistillBERTClass(torch.nn.Module):
6
+ def __init__(self, checkpoint_model):
7
+ #the super class is not important here!
8
+ super(DistillBERTClass, self).__init__()
9
+ #check the rmodel used here !
10
+ self.pre_trained_model = AutoModelForMaskedLM.from_pretrained(checkpoint_model,output_hidden_states=True)
11
+ self.linear = torch.nn.Linear(768, 768)
12
+ self.relu = torch.nn.ReLU()
13
+ self.dropout = torch.nn.Dropout(0.3)
14
+ self.classifier = torch.nn.Linear(768, 12)
15
+
16
+ def forward(self, input_ids, attention_mask):
17
+ pre_trained_output = self.pre_trained_model(input_ids=input_ids, attention_mask=attention_mask)
18
+ hidden_state = pre_trained_output.hidden_states[-1]
19
+
20
+ hidden_state = hidden_state[:, 0, :]
21
+ output = self.linear(hidden_state)
22
+ output = self.relu(output)
23
+ output = self.dropout(output)
24
+ output = self.classifier(output)
25
+ return output
26
+
27
+
28
+
29
+